# --- Network / optimizer setup -------------------------------------------
# `init_screen` is created earlier in the script; only its last two
# dimensions (height, width) are used here to size the networks.
_, _, screen_height, screen_width = init_screen.shape

# Number of discrete actions the agent chooses from. This is hard-coded
# (not read from an action space); it must match the game's action set.
n_actions = 4

policy_net = DQN(screen_height, screen_width, n_actions, layers=20).to(device)
target_net = DQN(screen_height, screen_width, n_actions, layers=20).to(device)

# Checkpoint file read when load_mode is True.
PATH = 'C:/Users/sagau/Google Drive/smaller1.pth'
optimizer = optim.Adam(policy_net.parameters(), lr=1e-4)

# Set to True to resume from the checkpoint at PATH instead of starting fresh.
load_mode = False

if load_mode:
    # Tensors are deserialized onto the CPU; load_state_dict then copies them
    # into the networks/optimizer that were already created above.
    model_dict = torch.load(PATH, map_location=torch.device('cpu'))
    i_episode = model_dict['epoch']
    optimizer.load_state_dict(model_dict['optimizer'])
    policy_net.load_state_dict(model_dict['state_dict'])
    # Target network starts from the same saved weights as the policy network.
    target_net.load_state_dict(model_dict['state_dict'])
    episode_durations = model_dict['episode_durations']
    total_reward_list = model_dict['total_reward_list']
    point_list = model_dict['point_list']
    plot_durations()
else:
    i_episode = 0
    # Target network starts as an exact copy of the freshly initialized policy.
    target_net.load_state_dict(policy_net.state_dict())
    episode_durations = []
    total_reward_list = []
    point_list = []

# Put target_net in eval mode on BOTH paths. Previously only the fresh-start
# branch did this, so a resumed run left target_net in training mode (which
# changes the behaviour of any dropout / batch-norm layers DQN may contain).
target_net.eval()

memory = ReplayMemory(10000000)
# NOTE(review): steps_done restarts at 0 even when resuming from a
# checkpoint; if exploration is scheduled on steps_done, a resumed run will
# start exploring from scratch. Confirm this is intended.
steps_done = 0
###################################################################### init_screen = get_screen() _, _, screen_height, screen_width = init_screen.shape # Get number of actions from gym action space n_actions = 4 #PATH = 'C:/Users/sagau/Desktop/Kaggle/TetrisRepo/models/model1_2.pth' policy_net = DQN(screen_height, screen_width, n_actions,layers=20) policy_net.eval() policy_net = DQN(screen_height, screen_width, n_actions,layers=20).to(device) PATH = 'C:/Users/sagau/Google Drive/transfersmaller.pth' model_dict = torch.load(PATH,map_location=torch.device('cpu')) policy_net.load_state_dict(model_dict['state_dict']) ###################################################################### # Play with model ! sleep_time = 0.2 game = Tetris(nb_rows=8,nb_cols=6) done = game.generate_block(choice=random.randint(0,3)) rows = 0 for t in count(): # for t in range(200): state = get_screen() action = select_action(state,0) # print(action.item()) game.play_active_block(action.item()) if game.block_reached_end(): rows = rows + game.clear_rows()