Example #1
def main(**kargs):
    initial_weights_file, initial_i_frame = latest(kargs['weights_dir'])

    print("Continuing using weights from file: ", initial_weights_file, "from", initial_i_frame)

    if kargs['theano_verbose']:
        theano.config.compute_test_value = 'warn'
        theano.config.exception_verbosity = 'high'
        theano.config.optimizer = 'fast_compile'

    ale = ag.init(display_screen=(kargs['visualize'] == 'ale'), record_dir=kargs['record_dir'])
    game = ag.SpaceInvadersGame(ale)


    def new_game():
        game.ale.reset_game()
        game.finished = False
        game.cum_reward = 0
        game.lives = 4
        return game

    replay_memory = dqn.ReplayMemory(size=kargs['dqn.replay_memory_size']) if not kargs['dqn.no_replay'] else None
    # dqn_algo = q.ConstAlgo([3])
    dqn_algo = dqn.DQNAlgo(game.n_actions(),
                           replay_memory=replay_memory,
                           initial_weights_file=initial_weights_file,
                           build_network=kargs['dqn.network'],
                           updates=kargs['dqn.updates'])

    dqn_algo.replay_start_size = kargs['dqn.replay_start_size']
    dqn_algo.final_epsilon = kargs['dqn.final_epsilon']
    dqn_algo.initial_epsilon = kargs['dqn.initial_epsilon']
    dqn_algo.i_frames = initial_i_frame

    dqn_algo.log_frequency = kargs['dqn.log_frequency']


    import Queue
    dqn_algo.mood_q = Queue.Queue() if kargs['show_mood'] else None

    if kargs['show_mood'] is not None:
        plot = kargs['show_mood']()

        def worker():
            while True:
                item = dqn_algo.mood_q.get()
                plot.show(item)
                dqn_algo.mood_q.task_done()

        import threading
        t = threading.Thread(target=worker)
        t.daemon = True
        t.start()

    print(str(dqn_algo))

    visualizer = ag.SpaceInvadersGameCombined2Visualizer() if kargs['visualize'] == 'q' else q.GameNoVisualizer()
    teacher = q.Teacher(new_game, dqn_algo, visualizer,
                        ag.Phi(skip_every=4), repeat_action=4, sleep_seconds=0)
    teacher.teach(500000)
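Across these examples, dqn.ReplayMemory is constructed with a size= or capacity= argument and is used through push, sample, len(), and (in some projects) can_provide_sample and push_count. The following is only a minimal sketch of that inferred interface, not the actual library code; the class name and the deque-backed FIFO behavior are assumptions.

import random
from collections import deque


class ReplayMemorySketch(object):
    """Minimal FIFO replay buffer matching the calls made in these examples
    (an assumed interface, not the real dqn.ReplayMemory)."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.memory = deque(maxlen=capacity)  # oldest transitions are dropped first
        self.push_count = 0

    def push(self, *transition):
        self.memory.append(transition)
        self.push_count += 1

    def sample(self, batch_size):
        return random.sample(list(self.memory), batch_size)

    def can_provide_sample(self, batch_size):
        return len(self.memory) >= batch_size

    def __len__(self):
        return len(self.memory)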
Example #2
    def __init__(self,
                 simulator,
                 policy_net,
                 target_net1,
                 target_net2,
                 memory_size,
                 online_memory_size=50000,
                 value_net_trainer=None,
                 state_is_image=False,
                 use_value_net=False):
        self.s = simulator
        self.policy_net = policy_net
        self.target_net1 = target_net1
        self.target_net2 = target_net2
        #        self.optimizer = optimizer
        self.optimizer = optim.Adam(self.policy_net.parameters(),
                                    lr=0.0001,
                                    weight_decay=7e-5)
        self.MEMORY_SIZE = memory_size
        self.memory = dqn.ReplayMemory(memory_size)
        self.memory_online = dqn.ReplayMemory(online_memory_size)
        self.value_net_trainer = value_net_trainer

        # set the hyperparameters
        self.BATCH_SIZE = 32
        self.GAMMA = 0.9
        self.IMG_CHANNEL = 3
        self.ACTION_BOUNDS = np.array([
            [-1.0, 1.0],  # vx
            [-1.0, 1.0],  # vy
            [-1.0, 1.0],  # vz
            #    [-1.0,1.0], # wx
            #    [-1.0,1.0], # wy
            #    [-1.0,1.0], # wz
            [-1.0, 1.0],  # terminate
            [0.0, 1.0],  # gripper position, close
            [0.0, 1.0]  # gripper position, open
        ])
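        # Each row above is [low, high] for one action dimension (shape: N_ACTIONS x 2);
        # the commented-out wx/wy/wz rows are presumably angular velocities left out of the action space.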
        self.N_ACTIONS = self.ACTION_BOUNDS.shape[0]
        self.CEM_ITER = 2
        self.SHOULD_TRAIN_VALUE_NET = True
        self.STATE_IS_IMAGE = state_is_image
        self.USE_VALUE_NET = use_value_net

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
Example #3
    return episode_reward / num_tests


# Driver code to run the model
# Offline
env.reset()
# Policy and Target networks
policy_net = dqn.DQN(dynamicsmodel.input_dim, dynamicsmodel.input_dim,
                     n_actions).to(device)
target_net = dqn.DQN(dynamicsmodel.input_dim, dynamicsmodel.input_dim,
                     n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
policy_net.train()
target_net.eval()
# Replay buffer
memory = dqn.ReplayMemory(10000)
dqn_loss_fn = nn.MSELoss()
dynaq = dynamicsmodel.DynaQ(64, n_actions).to(device)
# dynaq.eval()
# Optimizer
dqn_optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.01)
dyna_optimizer = torch.optim.Adam(dynaq.parameters(), lr=0.01)
algo_list = ["Dyna Q-Offline", "Dyna Q-Online"]
torch.nn.utils.clip_grad_norm_(policy_net.parameters(), 0.5)
torch.nn.utils.clip_grad_norm_(dynaq.parameters(), 0.5)
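# Note: clip_grad_norm_ only acts on the gradients that exist at the time of the call,
# so calling it once here at setup has no effect on later updates; it is normally
# applied after each loss.backward() inside the training loop.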
num_tests = 10

episode_rewards = []
avg_returns = []
# Offline
episode_rewards.append(train_offline())
Example #4
def main(**kargs):
    initial_weights_file, i_total_action = latest(kargs['weights_dir'])

    print("Continuing using weights from file: ", initial_weights_file, "from",
          i_total_action)

    if kargs['theano_verbose']:
        theano.config.compute_test_value = 'warn'
        theano.config.exception_verbosity = 'high'
        theano.config.optimizer = 'fast_compile'

    if kargs['game'] == 'simple_breakout':
        game = simple_breakout.SimpleBreakout()

        class P(object):
            def __init__(self):
                self.screen_size = 12

            def __call__(self, frames):
                return frames

        phi = P()
    else:
        ale = ag.init(game=kargs['game'],
                      display_screen=(kargs['visualize'] == 'ale'),
                      record_dir=kargs['record_dir'])
        game = ag.ALEGame(ale)
        phi = ag.Phi(method=kargs["phi_method"])

    replay_memory = dqn.ReplayMemory(size=kargs['dqn.replay_memory_size']
                                     ) if not kargs['dqn.no_replay'] else None
    algo = dqn.DQNAlgo(game.n_actions(),
                       replay_memory=replay_memory,
                       initial_weights_file=initial_weights_file,
                       build_network=kargs['dqn.network'],
                       updates=kargs['dqn.updates'],
                       screen_size=phi.screen_size)

    algo.replay_start_size = kargs['dqn.replay_start_size']
    algo.final_epsilon = kargs['dqn.final_epsilon']
    algo.initial_epsilon = kargs['dqn.initial_epsilon']
    algo.i_action = i_total_action

    algo.log_frequency = kargs['dqn.log_frequency']
    algo.target_network_update_frequency = kargs[
        'target_network_update_frequency']
    algo.final_exploration_frame = kargs['final_exploration_frame']

    import Queue
    algo.mood_q = Queue.Queue() if kargs['show_mood'] else None

    if kargs['show_mood'] is not None:
        plot = kargs['show_mood']()

        def worker():
            while True:
                item = algo.mood_q.get()
                plot.show(item)
                algo.mood_q.task_done()

        import threading
        t = threading.Thread(target=worker)
        t.daemon = True
        t.start()

    print(str(algo))

    if kargs['visualize'] != 'q':
        visualizer = q.GameNoVisualizer()
    else:
        if kargs['game'] == 'simple_breakout':
            visualizer = simple_breakout.SimpleBreakoutVisualizer(algo)
        else:
            visualizer = ag.ALEGameVisualizer(phi.screen_size)

    teacher = q.Teacher(
        game=game,
        algo=algo,
        game_visualizer=visualizer,
        phi=phi,
        repeat_action=kargs['repeat_action'],
        i_total_action=i_total_action,
        total_n_actions=50000000,
        max_actions_per_game=10000,
        skip_n_frames_after_lol=kargs['skip_n_frames_after_lol'],
        run_test_every_n=kargs['run_test_every_n'])
    teacher.teach()
Example #5
    def collect_data(self,
                     use_scripted_policy=True,
                     visualize=False,
                     n_files=None,
                     start_at=None,
                     epsilon=0.0,
                     dt=50e-3,
                     maxtime=20,
                     dryrun=False,
                     memory_capacity=5000):

        t = 0

        if torch.cuda.is_available():
            print("cuda is available :D")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        USE_VALUE_NET = False
        STATE_IS_IMAGE = True
        MEMORY_SIZE = memory_capacity
        GZIP_COMPRESSION_LEVEL = 3
        self.s.set_visualize(visualize)

        policy_net = None
        target_net1 = None
        target_net2 = None
        value_net = None
        value_net_trainer = None

        # create the target and policy networks
        policy_net = dqn.DQN().to(device)
        # default xavier init
        for m in policy_net.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight,
                                        gain=nn.init.calculate_gain('relu'))

        if os.path.isfile("vrep_arm_model.pt"):
            policy_net.load_state_dict(torch.load('vrep_arm_model.pt'))
            print("loaded existing model file")

        target_net1 = dqn.DQN().to(device)
        target_net2 = dqn.DQN().to(device)
        value_net = valuenet.ValueNet().to(device)
        value_net_trainer = valuenet.ValueNetTrainer(value_net)
        print(
            "number of parameters: ",
            sum(p.numel() for p in policy_net.parameters() if p.requires_grad))
        target_net1.load_state_dict(policy_net.state_dict())
        target_net1.eval()
        target_net2.load_state_dict(policy_net.state_dict())
        target_net2.eval()

        br = brain.Brain(
            simulator=self.s,  #only to access scripted policy
            policy_net=policy_net,
            target_net1=target_net1,
            target_net2=target_net2,
            memory_size=MEMORY_SIZE,
            value_net_trainer=value_net_trainer,
            state_is_image=STATE_IS_IMAGE,
            use_value_net=USE_VALUE_NET)

        # ============================================================================
        # train for num_episodes epochs
        # ============================================================================
        from itertools import count

        total_reached = 0
        reached = 0
        MAX_TIME = maxtime
        FRAME_SKIP = 1
        MAX_FILES = 6
        FILE_PREFIX = "dataset/replay_"
        num_file = 0
        PRINT_EVERY = 1

        br.memory = dqn.ReplayMemory(capacity=MEMORY_SIZE)

        if not use_scripted_policy:
            FILE_PREFIX = "dataset_online/replay_"
            MAX_FILES = int(MAX_FILES / 2)  # online data collection uses half as many files
        num_episodes = 2000000

        if n_files is not None:
            MAX_FILES = n_files

        if start_at is not None:
            num_file += start_at
            MAX_FILES += start_at

        start_time = time.time()
        total_episode_reward = 0

        for i_episode in range(num_episodes):
            episode_reward = 0
            if i_episode % PRINT_EVERY == 0:
                print("recording: episode", i_episode)
            # Initialize the environment and state
            self.s.reset()
            target_x, target_y, target_z = self.s.randomly_place_target()
            img_state, numerical_state = self.s.get_robot_state()
            error = np.random.normal(0, 0.1)
            error = 0

            for t in count():
                # Select and perform an action based on epsilon greedy
                # action is chosen based on the policy network
                img_state = torch.Tensor(img_state)
                numerical_state = torch.Tensor(numerical_state)

                # get the reward, detect if the task is done
                a = [0, 0, 0]
                action = None
                last_img_state = img_state
                last_numerical_state = numerical_state
                # record the action from the scripted exploration
                thresh = None
                action = None
                if use_scripted_policy:
                    action = br.select_action_scripted_exploration(thresh=1.0,
                                                                   error=error)
                else:
                    action = br.select_action_epsilon_greedy(
                        img_state, numerical_state, epsilon)

                self.s.set_control(action.view(-1).cpu().numpy())
                self.s.step()
                img_state, numerical_state = self.s.get_robot_state()
                reward_number, done = self.s.get_reward_and_done(
                    numerical_state)
                reward = torch.tensor([reward_number], device=device)

                episode_reward += (br.GAMMA**t) * reward_number

                if done and reward_number > 0:
                    #reached the target on its own
                    #                print("data collector: episode reached at timestep",t)
                    reached += 1

                if t > MAX_TIME:
                    # we will terminate if it doesn't finish
                    #                print("data collector: episode timeout")
                    done = True

                # Observe new state
                if not done:
                    state_img_tensor = torch.Tensor(img_state)
                    state_numerical_tensor = torch.Tensor(numerical_state)
                else:
                    state_img_tensor = None
                    state_numerical_tensor = None

                # Store the transition in memory
                # as the states are ndarrays, convert them to tensors;
                # the action and reward are already tensors
                if (t % FRAME_SKIP == 0):
                    br.memory.push(torch.Tensor(last_img_state),
                                   torch.Tensor(last_numerical_state),
                                   action.view(-1).float(), state_img_tensor,
                                   state_numerical_tensor, reward)

                if done:
                    #visualize and break
                    break

            total_episode_reward += episode_reward

            if i_episode % 10 == 0:
                time_per_ep = (time.time() - start_time) / 10.0
                start_time = time.time()
                print("reached target", reached, "/ 10 times, memory:",
                      len(br.memory), "/", MEMORY_SIZE, ",",
                      (100.0 * len(br.memory) / MEMORY_SIZE), "% full,",
                      time_per_ep, "sec/ep")
                total_reached += reached
                reached = 0

            if len(br.memory) >= br.memory.capacity:
                # if the buffer is full, save it and reset it
                filename = FILE_PREFIX + str(num_file).zfill(2) + ".gz"
                if not dryrun:
                    print("> saving file into", filename)
                    with gzip.GzipFile(
                            filename, 'wb',
                            compresslevel=GZIP_COMPRESSION_LEVEL) as handle:
                        cPickle.dump(br.memory,
                                     handle,
                                     protocol=cPickle.HIGHEST_PROTOCOL)
                    print("> saving completed")
                else:
                    print("> data collector: dryrun, not saving memory")
                num_file += 1
                if (num_file >= MAX_FILES):
                    print("> data_collector: all files collected, closing")
                    print("> total success rate:",
                          (total_reached * 1.0 / i_episode))
                    print("> mean episode reward:",
                          (total_episode_reward * 1.0 / i_episode))
                    return total_reached * 1.0 / i_episode, total_episode_reward * 1.0 / i_episode

                br.memory = dqn.ReplayMemory(capacity=MEMORY_SIZE)
                print("> data_collector: memory is full, saved as: " +
                      filename)
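The buffers written above are plain gzip-compressed pickles, so they can be read back with the same modules. A minimal loading sketch (load_replay_file and the example file name are illustrative, not part of the original code):

import gzip
import cPickle  # pickle on Python 3


def load_replay_file(filename):
    # Load one replay memory file written by collect_data() above.
    with gzip.GzipFile(filename, 'rb') as handle:
        return cPickle.load(handle)

# e.g. memory = load_replay_file("dataset/replay_00.gz")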
Example #6
        if success_rate is not None:
            timesteps.append(i)
            success_rates.append(success_rate)
            mean_rewards.append(mean_reward)
            plot_graph(timesteps, success_rates, "successrate")
            plot_graph(timesteps, mean_rewards, "mean_reward")

        if (success_rate is not None and success_rate > 0.5
                and not is_finetuning):
            print("==================================================")
            print("success rate over 0.5, reconfiguring online memory")
            print("==================================================")
            #switch to online training
            #remove all old memories and use only one file now
            reconfigure_for_finetuning()

    # optimize the model
    loss = br.optimize_model()
    losses += loss

    if i % 6000 == 0:  #(TARGET_UPDATE*FRAME_SKIP) == 0:
        br.update_target_net()

print("training complete, saving model")
torch.save(policy_net.state_dict(), MODEL_NAME + ".pt")
print("model saving completed")

print("end of offline training, beginning online training")
br.memory = dqn.ReplayMemory(br.MEMORY_SIZE)
br.memory_online = dqn.ReplayMemory(br.MEMORY_SIZE)
Example #7
def main(game_name, network_type, updates_method,
         target_network_update_frequency, initial_epsilon, final_epsilon,
         test_epsilon, final_exploration_frame, replay_start_size,
         deepmind_rmsprop_epsilon, deepmind_rmsprop_learning_rate,
         deepmind_rmsprop_rho, rmsprop_epsilon, rmsprop_learning_rate,
         rmsprop_rho, phi_type, phi_method, epoch_size, n_training_epochs,
         n_test_epochs, visualize, record_dir, show_mood, replay_memory_size,
         no_replay, repeat_action, skip_n_frames_after_lol,
         max_actions_per_game, weights_dir, algo_initial_state_file,
         log_frequency, theano_verbose):
    args = locals()

    if theano_verbose:
        theano.config.compute_test_value = 'warn'
        theano.config.exception_verbosity = 'high'
        theano.config.optimizer = 'fast_compile'

    if game_name == 'simple_breakout':
        game = simple_breakout.SimpleBreakout()

        class P(object):
            def __init__(self):
                self.screen_size = (12, 12)

            def __call__(self, frames):
                return frames

        phi = P()
    else:
        ale = ag.init(game=game_name,
                      display_screen=(visualize == 'ale'),
                      record_dir=record_dir)
        game = ag.ALEGame(ale)
        if phi_type == '4':
            phi = ag.Phi4(method=phi_method)
        elif phi_type == '1':
            phi = ag.Phi(method=phi_method)
        else:
            raise RuntimeError("Unknown phi: {phi}".format(phi=phi_type))

    if network_type == 'nature':
        build_network = network.build_nature
    elif network_type == 'nature_with_pad':
        build_network = network.build_nature_with_pad
    elif network_type == 'nips':
        build_network = network.build_nips
    elif network_type == 'nature_with_pad_he':
        build_network = network.build_nature_with_pad_he
    elif hasattr(network_type, '__call__'):
        build_network = network_type
    else:
        raise RuntimeError(
            "Unknown network: {network}".format(network=network_type))

    if updates_method == 'deepmind_rmsprop':
        updates = lambda loss, params: u.deepmind_rmsprop(
            loss, params,
            learning_rate=deepmind_rmsprop_learning_rate,
            rho=deepmind_rmsprop_rho,
            epsilon=deepmind_rmsprop_epsilon)
    elif updates_method == 'rmsprop':
        updates = lambda loss, params: lasagne.updates.rmsprop(
            loss, params,
            learning_rate=rmsprop_learning_rate,
            rho=rmsprop_rho,
            epsilon=rmsprop_epsilon)
    else:
        raise RuntimeError(
            "Unknown updates: {updates}".format(updates=updates_method))

    replay_memory = dqn.ReplayMemory(
        size=replay_memory_size) if not no_replay else None

    def create_algo():
        algo = dqn.DQNAlgo(game.n_actions(),
                           replay_memory=replay_memory,
                           build_network=build_network,
                           updates=updates,
                           screen_size=phi.screen_size)

        algo.replay_start_size = replay_start_size
        algo.final_epsilon = final_epsilon
        algo.initial_epsilon = initial_epsilon

        algo.log_frequency = log_frequency
        algo.target_network_update_frequency = target_network_update_frequency
        algo.final_exploration_frame = final_exploration_frame
        return algo

    algo_train = create_algo()
    algo_test = create_algo()
    algo_test.final_epsilon = test_epsilon
    algo_test.initial_epsilon = test_epsilon
    algo_test.epsilon = test_epsilon

    import Queue
    algo_train.mood_q = Queue.Queue() if show_mood else None

    if show_mood is not None:
        if show_mood == 'plot':
            plot = Plot()
        elif show_mood == "log":
            plot = Log()
        else:
            raise RuntimeError(
                "Unknown show_mood: {show_mood}".format(show_mood=show_mood))

        def worker():
            while True:
                item = algo_train.mood_q.get()
                plot.show(item)
                algo_train.mood_q.task_done()

        import threading
        t = threading.Thread(target=worker)
        t.daemon = True
        t.start()

    print(str(algo_train))

    if visualize != 'q':
        visualizer = q.GameNoVisualizer()
    else:
        if game_name == 'simple_breakout':
            visualizer = simple_breakout.SimpleBreakoutVisualizer(algo_train)
        else:
            visualizer = ag.ALEGameVisualizer(phi.screen_size)

    teacher = q.Teacher(game=game,
                        algo=algo_train,
                        game_visualizer=visualizer,
                        phi=phi,
                        repeat_action=repeat_action,
                        max_actions_per_game=max_actions_per_game,
                        skip_n_frames_after_lol=skip_n_frames_after_lol,
                        tester=False)

    tester = q.Teacher(game=game,
                       algo=algo_test,
                       game_visualizer=visualizer,
                       phi=phi,
                       repeat_action=repeat_action,
                       max_actions_per_game=max_actions_per_game,
                       skip_n_frames_after_lol=skip_n_frames_after_lol,
                       tester=True)

    q.teach_and_test(teacher,
                     tester,
                     n_epochs=n_training_epochs,
                     frames_to_test_on=n_test_epochs * epoch_size,
                     epoch_size=epoch_size,
                     state_dir=weights_dir,
                     algo_initial_state_file=algo_initial_state_file)
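The examples above configure exploration through initial_epsilon, final_epsilon, replay_start_size, and final_exploration_frame. A common interpretation of those parameters is a linear anneal of epsilon once learning starts; the sketch below shows that conventional schedule, not necessarily the exact rule inside dqn.DQNAlgo.

def epsilon_at(i_action, initial_epsilon, final_epsilon,
               replay_start_size, final_exploration_frame):
    # Act fully at random while the replay memory is being filled, then anneal
    # linearly from initial_epsilon to final_epsilon over final_exploration_frame
    # actions, and stay at final_epsilon afterwards.
    if i_action < replay_start_size:
        return 1.0
    progress = float(i_action - replay_start_size) / final_exploration_frame
    return max(final_epsilon,
               initial_epsilon - (initial_epsilon - final_epsilon) * min(progress, 1.0))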
Example #8
def train(White_policy_net, White_target_net, Black_policy_net,
          Black_target_net):
    whiteWins = 0
    blackWins = 0
    drawByNoProgress = 0
    drawByTooLongGame = 0
    drawByStaleMate = 0
    move_count = 0
    global White_loss
    global Black_loss
    global em

    for episode in range(past_episodes + 1, num_episodes + past_episodes + 1):
        print("Episode number: " + str(episode))
        print("Exploration Rate: " + agent.tell_me_exploration_rate(episode))
        terminal = False
        em.reset()  #reset the environment to start all over again
        White_tempMemory = dqn.ReplayMemory(
            per_game_memory_size)  #Create tempMemory for one match
        Black_tempMemory = dqn.ReplayMemory(
            per_game_memory_size)  #Create tempMemory for one match

        #Calculating available actions for just once, to initiate sequence
        em.calculate_available_actions("white")

        while True:
            #get the BitVectorBoard state from the environment as a tensor
            state = em.get_state()

            #If the game didn't end with the last move, now it's white's turn to move
            if not terminal:
                em, action = mcts.initializeTree(em, "white", move_time,
                                                 episode, White_policy_net,
                                                 agent,
                                                 device)  #white makes his move
                next_state = em.get_state()

                #We don't know what the reward will be until the game ends. So put 0 for now.
                state = state.unsqueeze(0)
                next_state = next_state.unsqueeze(0)
                gem = deepcopy(em)
                next_state_av_acts = gem.calculate_available_actions("black")
                White_tempMemory.push(
                    dqn.Experience(state, action, next_state,
                                   next_state_av_acts, 0, False))
                state = next_state

            #Check if game ends
            terminal, whiteWins, blackWins, drawByNoProgress, drawByTooLongGame, drawByStaleMate, White_tempMemory = hf.check_game_termination(
                em, "black", terminal, whiteWins, blackWins, drawByNoProgress,
                drawByTooLongGame, drawByStaleMate, White_tempMemory)
            #Check if game ends by no progress rule
            terminal, whiteWins, blackWins, drawByNoProgress, drawByTooLongGame, drawByStaleMate, White_tempMemory = hf.check_game_termination(
                em, "black", terminal, whiteWins, blackWins, drawByNoProgress,
                drawByTooLongGame, drawByStaleMate, White_tempMemory, True)

            #If the game didn't end with the last move, now it's black's turn to move
            if not terminal:
                em, action = mcts.initializeTree(em, "black", move_time,
                                                 episode, Black_policy_net,
                                                 agent,
                                                 device)  #black makes his move
                next_state = em.get_state()

                gem = deepcopy(em)
                next_state_av_acts = gem.calculate_available_actions("white")
                #We don't know what the reward will be until the game ends. So put 0 for now.
                next_state = next_state.unsqueeze(0)
                Black_tempMemory.push(
                    dqn.Experience(state, action, next_state,
                                   next_state_av_acts, 0, False))

            #Check if game ends
            terminal, whiteWins, blackWins, drawByNoProgress, drawByTooLongGame, drawByStaleMate, Black_tempMemory = hf.check_game_termination(
                em, "white", terminal, whiteWins, blackWins, drawByNoProgress,
                drawByTooLongGame, drawByStaleMate, Black_tempMemory)
            #Check if game ends by no progress rule
            terminal, whiteWins, blackWins, drawByNoProgress, drawByTooLongGame, drawByStaleMate, Black_tempMemory = hf.check_game_termination(
                em, "white", terminal, whiteWins, blackWins, drawByNoProgress,
                drawByTooLongGame, drawByStaleMate, Black_tempMemory, True)

            #-----------------------------------------------------FOR WHITE----------------------------------------
            #Returns true if length of the memory is greater than or equal to batch_size
            if White_memory.can_provide_sample(batch_size):
                White_experiences = White_memory.sample(
                    batch_size)  #sample experiences from memory
                Wstates, Wactions, Wrewards, Wnext_states, Wnext_state_av_actions = hf.extract_tensors(
                    White_experiences)  #extract them
                Wstates = Wstates.to(device)
                Wactions = Wactions.to(device)
                Wrewards = Wrewards.to(device)
                Wnext_states = Wnext_states.to(device)

                #get the current q values to calculate loss afterwards
                #Shape of policy_net(states): [batchsize, 282]
                #Shape of actions: [batchsize] (a single-row tensor)
                #Shape of actions.unsqueeze(-1): [batchsize, 1] (one action index per row)
                #Shape of current_q_values: [batchsize, 1] (the Q-value of the action taken in each row)
                Wcurrent_q_values = White_policy_net(Wstates).gather(
                    dim=1, index=Wactions.unsqueeze(-1)).to(device)
                #Shape of next_q_values:	[batchsize,282]
                Wnext_q_values = White_target_net(Wnext_states).detach().to(
                    device)
                Wnext_state_maxq = []

                #batch_corrector_start = timeit.default_timer()
                #Getting correct q-values from next_state. To do this, we have to compare against available actions
                for i in range(batch_size):
                    q_values = Wnext_q_values[i]
                    q_values = q_values.to(device)
                    available_actions = Wnext_state_av_actions[i]
                    if len(available_actions) == 0:
                        Wnext_state_maxq.append(
                            torch.tensor(-1000000, dtype=torch.float32))
                        continue
                    indices = torch.topk(q_values,
                                         len(q_values))[1].to(device).detach()

                    for j in range(len(indices)):
                        max_index = indices[j].to(device).detach()
                        #If the top-ranked action is illegal, fall back to the next-highest legal action.
                        if max_index in available_actions:
                            break
                    Wnext_state_maxq.append(q_values[max_index])
                #batch_corrector_end = timeit.default_timer()
                #print("Batch Corrector Time: " + str(batch_corrector_end - batch_corrector_start))

                # set y_j to r_j for terminal state, otherwise to r_j + gamma*max(Q). Garbage q-values for terminal states are not used.
                target_q_values = torch.cat(
                    tuple(Wrewards[i].unsqueeze(0) if White_experiences[i][5]
                          else Wrewards[i].unsqueeze(0) +
                          gamma * Wnext_state_maxq[i].unsqueeze(0)
                          for i in range(batch_size)))
                target_q_values = target_q_values.to(device)
                #Shape of target_q_values: [batchsize, 1]

                #clear the old gradients. we only focus on this batch. pytorch accumulates gradients in default.
                White_optimizer.zero_grad()
                # returns a new Tensor, detached from the current graph, the result will never require gradient
                target_q_values = target_q_values.detach()

                White_loss = F.mse_loss(Wcurrent_q_values, target_q_values)
                White_loss.backward()
                White_optimizer.step()  #take a step based on the gradients

            #-----------------------------------------------------FOR BLACK----------------------------------------
            #Returns true if length of the memory is greater than or equal to batch_size
            if Black_memory.can_provide_sample(batch_size):
                Black_experiences = Black_memory.sample(
                    batch_size)  #sample experiences from memory
                Bstates, Bactions, Brewards, Bnext_states, Bnext_state_av_actions = hf.extract_tensors(
                    Black_experiences)  #extract them
                Bstates = Bstates.to(device)
                Bactions = Bactions.to(device)
                Brewards = Brewards.to(device)
                Bnext_states = Bnext_states.to(device)

                #get the current q values to calculate loss afterwards
                #Shape of policy_net(states): [batchsize, 282]
                #Shape of actions: [batchsize] (a single-row tensor)
                #Shape of actions.unsqueeze(-1): [batchsize, 1] (one action index per row)
                #Shape of current_q_values: [batchsize, 1] (the Q-value of the action taken in each row)
                Bcurrent_q_values = Black_policy_net(Bstates).gather(
                    dim=1, index=Bactions.unsqueeze(-1)).to(device)
                #Shape of next_q_values:	[batchsize,282]
                Bnext_q_values = Black_target_net(Bnext_states).detach().to(
                    device)
                Bnext_state_maxq = []

                #batch_corrector_start = timeit.default_timer()
                #Getting correct q-values from next_state. To do this, we have to compare against available actions
                for i in range(batch_size):
                    q_values = Bnext_q_values[i]
                    q_values = q_values.to(device)
                    available_actions = Bnext_state_av_actions[i]
                    if len(available_actions) == 0:
                        Bnext_state_maxq.append(
                            torch.tensor(-1000000, dtype=torch.float32))
                        continue
                    indices = torch.topk(q_values,
                                         len(q_values))[1].to(device).detach()

                    for j in range(len(indices)):
                        max_index = indices[j].to(device).detach()
                        #If the top-ranked action is illegal, fall back to the next-highest legal action.
                        if max_index in available_actions:
                            break
                    Bnext_state_maxq.append(q_values[max_index])
                #batch_corrector_end = timeit.default_timer()
                #print("Batch Corrector Time: " + str(batch_corrector_end - batch_corrector_start))

                # set y_j to r_j for terminal state, otherwise to r_j + gamma*max(Q). Garbage q-values for terminal states are not used.
                target_q_values = torch.cat(
                    tuple(Brewards[i].unsqueeze(0) if Black_experiences[i][5]
                          else Brewards[i].unsqueeze(0) +
                          gamma * Bnext_state_maxq[i].unsqueeze(0)
                          for i in range(batch_size)))
                target_q_values = target_q_values.to(device)
                #Shape of target_q_values: [batchsize, 1]

                #clear the old gradients. we only focus on this batch. pytorch accumulates gradients in default.
                Black_optimizer.zero_grad()
                # returns a new Tensor, detached from the current graph, the result will never require gradient
                target_q_values = target_q_values.detach()

                Black_loss = F.mse_loss(Bcurrent_q_values, target_q_values)
                Black_loss.backward()
                Black_optimizer.step()  #take a step based on the gradients
                #---------------------------------------------------------------------

            #In a terminal state we do not take another step; we end the episode instead
            #and record the number of moves played.
            if terminal:
                print(
                    str(White_tempMemory.push_count +
                        Black_tempMemory.push_count) +
                    " moves played in this match.\n")
                move_count += White_tempMemory.push_count
                move_count += Black_tempMemory.push_count
                #Editing the last memory, so that its terminal value is True
                White_tempMemory.memory[-1] = White_tempMemory.memory[
                    -1]._replace(terminal=True)
                Black_tempMemory.memory[-1] = Black_tempMemory.memory[
                    -1]._replace(terminal=True)

                #Moving full tuples to big memory, and deleting the temp memory
                White_memory.pushBlock(White_tempMemory)
                Black_memory.pushBlock(Black_tempMemory)
                del White_tempMemory
                del Black_tempMemory
                break

        #Update target network with weights and biases in the policy network
        #Also create new model files
        if episode % target_update == 0:
            White_target_net.load_state_dict(White_policy_net.state_dict())
            Black_target_net.load_state_dict(Black_policy_net.state_dict())
            torch.save(
                {
                    'episode': episode,
                    'White_model_state_dict': White_policy_net.state_dict(),
                    'White_optimizer_state_dict': White_optimizer.state_dict(),
                    'White_loss': White_loss,
                    'Black_model_state_dict': Black_policy_net.state_dict(),
                    'Black_optimizer_state_dict': Black_optimizer.state_dict(),
                    'Black_loss': Black_loss,
                }, PATH_TO_DIRECTORY + "MiniChess-trained-model" +
                str(episode) + ".tar")
            print("Episode:" + str(episode) + " -------Weights are updated!")

    print("White Wins: " + str(whiteWins) + "\t\t\tWin Rate: %" +
          str(100 * whiteWins / num_episodes))
    print("Black Wins: " + str(blackWins) + "\t\t\tWin Rate: %" +
          str(100 * blackWins / num_episodes))
    print("Draw By No Progress: " + str(drawByNoProgress) +
          "\t\tNo Progress Rate: %" +
          str(100 * drawByNoProgress / num_episodes))
    print("Draw By Too Long Game: " + str(drawByTooLongGame) +
          "\tToo Long Game Rate: %" +
          str(100 * drawByTooLongGame / num_episodes))
    print("Draw By Stalemate: " + str(drawByStaleMate) +
          "\t\tStalemate Rate: %" + str(100 * drawByStaleMate / num_episodes))
    print("Average Move per game: " + str(move_count / num_episodes))
    return None
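The target_q_values construction above implements the standard Bellman target: y_j = r_j for terminal transitions, and y_j = r_j + gamma * max_a Q_target(s', a) otherwise. The same target can be computed in a single vectorized call; the sketch below assumes rewards, next_state_maxq, and a boolean terminal mask given as 1-D tensors of length batch_size (the function name and arguments are illustrative, not from the original).

import torch


def bellman_targets(rewards, next_state_maxq, terminal_mask, gamma):
    # y = r where terminal, r + gamma * max_a Q_target(s', a) elsewhere
    return torch.where(terminal_mask, rewards, rewards + gamma * next_state_maxq)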
Example #9
if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #Initialize policy nets and environment.
    White_policy_net = dqn.DQN().to(device)
    Black_policy_net = dqn.DQN().to(device)
    em = minichess.MiniChess(device)
    last_trained_model = fileoperations.find_last_edited_file(
        PATH_TO_DIRECTORY)  #Returns the last trained model when multiple models exist.
    num_episodes = int(sys.argv[2]) if len(sys.argv) > 2 else 100

    if (sys.argv[1]) == "train":
        strategy = dqn.EpsilonGreedyStrategy(eps_start, eps_end, eps_decay)
        agent = dqn.Agent(strategy, device)
        White_memory = dqn.ReplayMemory(memory_size)
        Black_memory = dqn.ReplayMemory(memory_size)
        White_target_net = dqn.DQN().to(device)
        Black_target_net = dqn.DQN().to(device)
        past_episodes = 0  #number of episodes played before; may be overwritten in the if block below.
        White_loss = 0
        Black_loss = 0
        lr = 0.2
        White_optimizer = optim.Adam(params=White_policy_net.parameters(),
                                     lr=lr)
        Black_optimizer = optim.Adam(params=Black_policy_net.parameters(),
                                     lr=lr)

        if last_trained_model is not None:
            print("***Last trained model: " + last_trained_model)
            checkpoint = torch.load(last_trained_model, map_location=device)
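            # (Example truncated here.  A hypothetical continuation, restoring the fields
            # written by torch.save in the training example above, might look like:)
            # past_episodes = checkpoint['episode']
            # White_policy_net.load_state_dict(checkpoint['White_model_state_dict'])
            # White_optimizer.load_state_dict(checkpoint['White_optimizer_state_dict'])
            # Black_policy_net.load_state_dict(checkpoint['Black_model_state_dict'])
            # Black_optimizer.load_state_dict(checkpoint['Black_optimizer_state_dict'])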
Example #10
            # Store the transition in memory
            # as the states are ndarrays, convert them to tensors;
            # the action and reward are already tensors
            if (t % self.FRAME_SKIP == 0):
                self.br.memory.push(torch.Tensor(last_img_state),
                                    torch.Tensor(last_numerical_state),
                                    action.view(-1).float(), state_img_tensor,
                                    state_numerical_tensor, reward)

            if done:
                break

        return reached, episode_reward


if __name__ == "__main__":
    # Create new threads
    memory_buffer = dqn.ReplayMemory(1000000)
    thread1 = DataCollectionThread("DCThread",
                                   memory_buffer,
                                   maxtime=20,
                                   dt=0.05,
                                   port=19991)
    # Start new Threads
    thread1.daemon = True  # daemon thread exits with the main thread, which is good
    thread1.start()
    while True:
        #print(memory_buffer.memory)
        time.sleep(0.5)
    print("Exiting Main Thread")