Example #1
    def __init__(self, game_name, config, global_model_data_path, local_test_flag):
        if not local_test_flag:
            mkdirs(global_model_data_path+config.DRL.Learn.data_save_path)
        self.game_name = game_name
        self.data_save_path = global_model_data_path+config.DRL.Learn.data_save_path
        self.config = config
        self.global_iter = 0
        self.ckpt_dir = global_model_data_path+self.config.DRL.Learn.ckpt_dir
        self.ckpt_save_iter = self.config.DRL.Learn.ckpt_save_iter
        if not local_test_flag:
            mkdirs(self.ckpt_dir)

        self.apply_prioritize_memory = False
        if self.apply_prioritize_memory:
            self.memory = PrioritizedReplay(capacity=self.config.DRL.Learn.replay_memory_size)
        else:
            # store the previous observations in replay memory
            self.memory = deque()

        if self.game_name == 'flappybird':  # FlappyBird uses PyTorch
            use_cuda = config.DRL.Learn.cuda and torch.cuda.is_available()
            self.device = 'cuda' if use_cuda else 'cpu'
            self.actions_number = self.config.DRL.Learn.actions
            # self.nn = DeepQNetwork().to(self.device)
            self.nn = FlappyBirdDQN().to(self.device)
            self.optim = optim.Adam(self.nn.parameters(), lr=self.config.DRL.Learn.learning_rate)
            if config.DRL.Learn.ckpt_load:
                self.load_checkpoint(model_name='flappy_bird_model')
            self.game_state = FlappyBird()
            if torch.cuda.is_available():
                torch.cuda.manual_seed(123)
            else:
                torch.manual_seed(123)
        elif self.game_name in ('Assault-v0', 'Breakout-v0', 'SpaceInvaders-v0'):
            if self.game_name == 'Assault-v0':
                game_model_name = 'Assault-v0.tfmodel'
            elif self.game_name == 'Breakout-v0':
                game_model_name = 'Breakout-v0.npz'
            elif self.game_name == 'SpaceInvaders-v0':
                game_model_name = 'SpaceInvaders-v0.tfmodel'
            else:
                raise ValueError('Unknown game name {0}'.format(self.game_name))

            self.env = self.get_player_atari(train=False)
            num_actions = self.env.action_space.n
            self.nn = OfflinePredictor(PredictConfig(
                # model=Model(),
                model=Model(num_actions=num_actions, image_size=(84, 84)),
                session_init=SmartInit(self.ckpt_dir+game_model_name),
                input_names=['state'],
                output_names=['policy', 'pred_value']))

            self.config.DRL.Learn.actions = num_actions
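
The examples in this listing all follow the same tensorpack inference pattern: describe the graph (a ModelDesc subclass or a tower function), point session_init at a checkpoint, name the input and output tensors, wrap everything in a PredictConfig, and build an OfflinePredictor that is then called like a plain function with numpy arrays. Below is a minimal, self-contained sketch of that pattern; TinyModel, the tensor names, and the zero input are toy placeholders (not taken from any project above), and a TF1-compatible setup is assumed, as in the examples.

import numpy as np
import tensorflow as tf
from tensorpack import ModelDesc, PredictConfig, OfflinePredictor


class TinyModel(ModelDesc):
    """Toy model: names two tensors so they can be requested via output_names."""
    def inputs(self):
        return [tf.TensorSpec((None, 84, 84, 4), tf.float32, 'state')]

    def build_graph(self, state):
        mean = tf.reduce_mean(state, axis=[1, 2, 3])
        tf.identity(tf.stack([mean, -mean], axis=1), name='policy')
        tf.identity(mean, name='pred_value')


pred_config = PredictConfig(
    model=TinyModel(),
    session_init=None,  # the real examples pass SmartInit/SaverRestore/get_model_loader here
    input_names=['state'],
    output_names=['policy', 'pred_value'])
predictor = OfflinePredictor(pred_config)  # builds the graph and a session once
policy, value = predictor(np.zeros((1, 84, 84, 4), np.float32))
print(policy.shape, value.shape)
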
Example #2
def test_only(args):
    from imageio import imsave
    data_folder = args.get("data_folder")
    test_ckpt = args.get("test_ckpt")
    test_folder = args.get("test_folder")
    if not os.path.exists(test_folder):
        os.makedirs(test_folder)
    image_size = 224
    pred_config = PredictConfig(
        model=ProgressiveSynTex(args),
        session_init=SmartInit(test_ckpt),
        input_names=["pre_image_input", "image_target"],
        output_names=['stages-target/viz', 'loss_output']
    )
    predictor = OfflinePredictor(pred_config)
    test_ds = get_data(data_folder, image_size, isTrain=False)
    test_ds.reset_state()
    idx = 0
    losses = list()
    print("------------------ predict --------------")
    for pii, it in test_ds:
        output_array, loss_output = predictor(pii, it)
        if output_array.ndim == 4:
            for i in range(output_array.shape[0]):
                imsave(os.path.join(test_folder, "test-{}.jpg".format(idx)), output_array[i])
                idx += 1
        else:
            imsave(os.path.join(test_folder, "test-{}.jpg".format(idx)), output_array)
            idx += 1
        losses.append(loss_output)
        print("loss #", idx, "=", loss_output)
    print("Test and save", idx, "images to", test_folder, "avg loss =", np.mean(losses))
Example #3
def init(args=None, is_running=0, pt=None):
    global ckpt2
    # Network
    model = Net2()
    if is_running == 1:
        if pt == "":
            ckpt2 = tf.train.latest_checkpoint(logdir2)
        else:
            ckpt2 = '{}/{}'.format(logdir2, pt)
    else:
        ckpt2 = '{}/{}'.format(
            logdir2,
            args.ckpt) if args.ckpt else tf.train.latest_checkpoint(logdir2)
    session_inits = []
    if ckpt2:
        session_inits.append(SaverRestore(ckpt2))
    pred_conf = PredictConfig(
        model=model,
        input_names=['x_ppgs', 'x_mfccs', 'y_spec', 'y_mel'],
        output_names=['pred_spec', "ppgs"],
        session_init=ChainInit(session_inits))
    global predictor
    predictor = OfflinePredictor(pred_conf)
    if is_running == 1:
        return jsonify({"code": 0, "ckpt": ckpt2})
    def thread_function(self):
        """
        Run on secondary thread
        """
        pred = OfflinePredictor(
            PredictConfig(model=Model(IMAGE_SIZE, FRAME_HISTORY, self.METHOD,
                                      self.NUM_ACTIONS, GAMMA, ""),
                          session_init=get_model_loader(self.fname_model.name),
                          input_names=['state'],
                          output_names=['Qvalue']))

        # demo pretrained model one episode at a time
        if self.task_value == 'Play':
            play_n_episodes(get_player(files_list=self.selected_list,
                                       viz=0.01,
                                       data_type=self.window.usecase,
                                       saveGif=self.GIF_value,
                                       saveVideo=self.video_value,
                                       task='play'),
                            pred,
                            self.num_files,
                            viewer=self.window)
        # run episodes in parallel and evaluate pretrained model
        elif self.task_value == 'Evaluation':
            play_n_episodes(get_player(files_list=self.selected_list,
                                       viz=0.01,
                                       data_type=self.window.usecase,
                                       saveGif=self.GIF_value,
                                       saveVideo=self.video_value,
                                       task='eval'),
                            pred,
                            self.num_files,
                            viewer=self.window)
    def _make_pred_func(self, load):
        from train import ResNetFPNTrackModel
        pred_model = ResNetFPNTrackModel()
        predcfg = PredictConfig(
            model=pred_model,
            session_init=get_model_loader(load),
            input_names=pred_model.get_inference_tensor_names()[0],
            output_names=pred_model.get_inference_tensor_names()[1])
        return OfflinePredictor(predcfg)
    def load_model(self):
        print('Loading Model...')
        model_path = self.model_path
        model_constructor = self.get_model()
        pred_config = PredictConfig(
            model=model_constructor(self.nr_types, self.input_shape,
                                    self.mask_shape, self.input_norm),
            session_init=get_model_loader(model_path),
            input_names=self.input_tensor_names,
            output_names=self.output_tensor_names)
        self.predictor = OfflinePredictor(pred_config)
Example #7
    def __init__(self, name, need_network=True, need_img=True, model="best"):
        super().__init__(name=name, is_deterministic=True)
        self._resizer = CustomResize(cfg.PREPROC.TEST_SHORT_EDGE_SIZE,
                                     cfg.PREPROC.MAX_SIZE)
        self._prev_box = None
        self._ff_gt_feats = None
        self._need_network = need_network
        self._need_img = need_img
        self._rotated_bbox = None

        if need_network:
            logger.set_logger_dir(
                "/tmp/test_log_/" + str(random.randint(0, 10000)), 'd')
            if model == "best":
                load = "train_log/hard_mining3/model-1360500"
            elif model == "nohardexamples":
                load = "train_log/condrcnn_all_2gpu_lrreduce2/model-1200500"
            elif model == "newrpn":
                load = "train_log/newrpn1/model"
            elif model == "resnet50_nohardexamples":
                load = "train_log/condrcnn_all_resnet50/model-1200500"
                cfg.BACKBONE.RESNET_NUM_BLOCKS = [3, 4, 6, 3]
            elif model == "resnet50":
                load = "train_log/hard_mining3_resnet50/model-1360500"
                cfg.BACKBONE.RESNET_NUM_BLOCKS = [3, 4, 6, 3]
            elif model == "gotonly":
                load = "train_log/hard_mining3_onlygot/model-1361000"
            elif model.startswith("checkpoint:"):
                load = model.replace("checkpoint:", "")
            else:
                assert False, ("unknown model", model)
            from dataset import DetectionDataset
            # init tensorpack model
            # cfg.freeze(False)
            DetectionDataset()  # initialize the config with information from our dataset

            cfg.EXTRACT_GT_FEATURES = True
            cfg.MODE_TRACK = False
            extract_model = ResNetFPNModel()
            extract_ff_feats_cfg = PredictConfig(
                model=extract_model,
                session_init=get_model_loader(load),
                input_names=['image', 'roi_boxes'],
                output_names=['rpn/feature'])
            finalize_configs(is_training=False)
            self._extract_func = OfflinePredictor(extract_ff_feats_cfg)

            cfg.EXTRACT_GT_FEATURES = False
            cfg.MODE_TRACK = True
            cfg.USE_PRECOMPUTED_REF_FEATURES = True
            self._pred_func = self._make_pred_func(load)
    def __init__(self):
        super().__init__(name='ArgmaxTracker', is_deterministic=True)
        self._ref_img = None
        self._ref_bbox = None
        self._prev_box = None
        model = self._init_model()
        load = "train_log/condrcnn_onlygot/model-460000"
        predcfg = PredictConfig(
            model=model,
            session_init=get_model_loader(load),
            input_names=model.get_inference_tensor_names()[0],
            output_names=model.get_inference_tensor_names()[1])
        self._pred_func = OfflinePredictor(predcfg)
Example #9
def _test():
    import numpy as np
    # from tensorpack.tfutils import TowerContext
    from tensorpack import PredictConfig, OfflinePredictor

    pretrained = False

    models = [
        resnet10,
        resnet12,
        resnet14,
        resnet16,
        resnet18_wd4,
        resnet18_wd2,
        resnet18_w3d4,
        resnet18,
        resnet34,
        resnet50,
        resnet50b,
        resnet101,
        resnet101b,
        resnet152,
        resnet152b,
        resnet200,
        resnet200b,
        seresnet18,
        seresnet34,
        seresnet50,
        seresnet50b,
        seresnet101,
        seresnet101b,
        seresnet152,
        seresnet152b,
        seresnet200,
        seresnet200b,
    ]

    for model in models:

        net = model(pretrained=pretrained)

        pred_config = PredictConfig(session_init=None,
                                    model=net,
                                    input_names=['input'],
                                    output_names=['label'])

        pred = OfflinePredictor(pred_config)
        img = np.zeros((224, 224, 3), np.uint8)
        prediction = pred([img])[0]
        print(prediction)
        pass
Example #10
def test(args):
    from imageio import imsave
    from tictoc import Timer
    data_folder = args.get("data_folder")
    image_size = args.get("image_size")
    batch_size = args.get("batch_size") or BATCH
    test_ckpt = args.get("test_ckpt")
    test_folder = args.get("test_folder")
    if not os.path.exists(test_folder):
        os.makedirs(test_folder)
    pred_config = PredictConfig(
        model=Style2PO(args),
        session_init=SmartInit(test_ckpt),
        input_names=["image_input", "image_target"],
        output_names=['stages-target/viz', 'loss_output'])
    predictor = OfflinePredictor(pred_config)
    zmin, zmax = (0, 1) if args.get("act_input") == "identity" else (-1, 1)
    test_ds = get_data(data_folder, image_size, False, zmin, zmax, batch_size)
    test_ds.reset_state()
    idx = 0
    losses = list()
    print("------------------ predict --------------")
    timer = Timer("predict", tic=True, show=Timer.STDOUT)
    for rz, it in test_ds:
        output_array, loss_output = predictor(rz, it)
        if output_array.ndim == 4:
            for i in range(output_array.shape[0]):
                imsave(os.path.join(test_folder, "test-{}.jpg".format(idx)),
                       output_array[i])
                idx += 1
        else:
            imsave(os.path.join(test_folder, "test-{}.jpg".format(idx)),
                   output_array)
            idx += 1
        losses.append(loss_output)
        print("loss #", idx, "=", loss_output)
    timer.toc(Timer.STDOUT)
    print("Test and save", idx, "images to", test_folder, "avg loss =",
          np.mean(losses))
Example #11
def export_eval_protobuf_model(checkpoint_dir, model_name, dataset, quant_type,
                               output_file, batch_size):
    _, test_data, (img_shape, label_shape) = datasets.DATASETS[dataset]()

    model_func, input_spec, output_spec = get_model_func(
        "eval", model_name, quant_type, img_shape, label_shape[0])
    input_names = [i.name for i in input_spec]
    output_names = [o.name for o in output_spec]
    predictor_config = PredictConfig(session_init=SaverRestore(checkpoint_dir +
                                                               "/checkpoint"),
                                     tower_func=model_func,
                                     input_signature=input_spec,
                                     input_names=input_names,
                                     output_names=output_names,
                                     create_graph=False)

    print("Exporting optimised protobuf graph...")
    K.set_learning_phase(False)
    ModelExporter(predictor_config).export_compact(output_file, optimize=False)

    K.clear_session()
    pred = OfflinePredictor(predictor_config)

    test_data = BatchData(test_data, batch_size, remainder=True)
    test_data.reset_state()

    num_correct = 0
    num_processed = 0
    for img, label in tqdm(test_data):
        num_correct += sum(pred(img)[0].argmax(axis=1) == label.argmax(axis=1))
        num_processed += img.shape[0]

    print("Exported model has accuracy {:.4f}".format(num_correct /
                                                      num_processed))

    return input_names, output_names, {i.name: i.shape for i in input_spec}
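
Example #11 writes a frozen GraphDef with ModelExporter(predictor_config).export_compact(output_file). For completeness, here is a minimal sketch of loading such a protobuf back with plain TF1-style TensorFlow and running it; the file path and the 'input:0'/'output:0' tensor names below are placeholders, the real names being the input_names/output_names returned by export_eval_protobuf_model.

import numpy as np
import tensorflow as tf  # TF1-style graph mode assumed, as in the example above

with tf.gfile.GFile("exported_model.pb", "rb") as f:  # placeholder path
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name="")  # keep the original tensor names

with tf.Session(graph=graph) as sess:
    batch = np.zeros((1, 32, 32, 3), np.float32)  # placeholder input batch
    logits = sess.run("output:0", feed_dict={"input:0": batch})
    print(logits.argmax(axis=1))
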
Example #12
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    # ROM_FILE = args.rom
    METHOD = args.algo
    # set num_actions
    init_player = MedicalPlayer(directory=data_dir,
                                files_list=test_list,
                                screen_dims=IMAGE_SIZE,
                                spacing=SPACING)
    NUM_ACTIONS = init_player.action_space.n
    num_validation_files = init_player.files.num_files

    if args.task != 'train':
        assert args.load is not None
        pred = OfflinePredictor(PredictConfig(
            model=Model(),
            session_init=get_model_loader(args.load),
            input_names=['state'],
            output_names=['Qvalue']))
        if args.task == 'play':
            t0 = time.time()
            play_n_episodes(get_player(directory=data_dir,
                                       files_list=test_list, viz=0.01,
                                       saveGif=args.saveGif,
                                       saveVideo=args.saveVideo),
                            pred, num_validation_files)

            t1 = time.time()
            print(t1-t0)
        elif args.task == 'eval':
            eval_model_multithread(pred, EVAL_EPISODE, get_player)
    else:
Example #13
    ##########################################################
    # initialize states and Q-values for the various agents
    state_names = []
    qvalue_names = []
    for i in range(0, args.agents):
        state_names.append('state_{}'.format(i))
        qvalue_names.append('Qvalue_{}'.format(i))

############################################################

    if args.task != 'train':
        assert args.load is not None
        pred = OfflinePredictor(
            PredictConfig(model=Model(agents=args.agents),
                          session_init=get_model_loader(args.load),
                          input_names=state_names,
                          output_names=qvalue_names))
        # demo pretrained model one episode at a time
        if args.task == 'play':
            play_n_episodes(get_player(files_list=args.files,
                                       viz=0.01,
                                       saveGif=args.saveGif,
                                       saveVideo=args.saveVideo,
                                       task='play',
                                       agents=args.agents),
                            pred,
                            num_files,
                            agents=args.agents,
                            reward_strategy=args.reward_strategy)
        # run episodes in parallel and evaluate pretrained model
Example #14
    METHOD = args.algo
    # load files into env to set num_actions, num_validation_files
    init_player = MedicalPlayer(
        files_list=args.files,  #files_list=files_list,
        data_type=args.type,
        screen_dims=IMAGE_SIZE,
        task='play')
    NUM_ACTIONS = init_player.action_space.n
    num_files = init_player.files.num_files

    if args.task != 'train':
        assert args.load is not None
        pred = OfflinePredictor(
            PredictConfig(model=Model(IMAGE_SIZE, FRAME_HISTORY, METHOD,
                                      NUM_ACTIONS, GAMMA, args.trainable),
                          session_init=get_model_loader(args.load),
                          input_names=['state'],
                          output_names=['Qvalue']))
        # demo pretrained model one episode at a time
        if args.task == 'play':
            play_n_episodes(get_player(files_list=args.files,
                                       data_type=args.type,
                                       viz=0,
                                       saveGif=args.saveGif,
                                       saveVideo=args.saveVideo,
                                       task='play'),
                            pred,
                            num_files,
                            viewer=None)

        # run episodes in parallel and evaluate pretrained model
Example #15
class DRLDataGenerator():
    def __init__(self, game_name, config, global_model_data_path, local_test_flag):
        if not local_test_flag:
            mkdirs(global_model_data_path+config.DRL.Learn.data_save_path)
        self.game_name = game_name
        self.data_save_path = global_model_data_path+config.DRL.Learn.data_save_path
        self.config = config
        self.global_iter = 0
        self.ckpt_dir = global_model_data_path+self.config.DRL.Learn.ckpt_dir
        self.ckpt_save_iter = self.config.DRL.Learn.ckpt_save_iter
        if not local_test_flag:
            mkdirs(self.ckpt_dir)

        self.apply_prioritize_memory = False
        if self.apply_prioritize_memory:
            self.memory = PrioritizedReplay(capacity=self.config.DRL.Learn.replay_memory_size)
        else:
            # store the previous observations in replay memory
            self.memory = deque()

        if self.game_name == 'flappybird':  # FlappyBird uses PyTorch
            use_cuda = config.DRL.Learn.cuda and torch.cuda.is_available()
            self.device = 'cuda' if use_cuda else 'cpu'
            self.actions_number = self.config.DRL.Learn.actions
            # self.nn = DeepQNetwork().to(self.device)
            self.nn = FlappyBirdDQN().to(self.device)
            self.optim = optim.Adam(self.nn.parameters(), lr=self.config.DRL.Learn.learning_rate)
            if config.DRL.Learn.ckpt_load:
                self.load_checkpoint(model_name='flappy_bird_model')
            self.game_state = FlappyBird()
            if torch.cuda.is_available():
                torch.cuda.manual_seed(123)
            else:
                torch.manual_seed(123)
        elif self.game_name in ('Assault-v0', 'Breakout-v0', 'SpaceInvaders-v0'):
            if self.game_name == 'Assault-v0':
                game_model_name = 'Assault-v0.tfmodel'
            elif self.game_name == 'Breakout-v0':
                game_model_name = 'Breakout-v0.npz'
            elif self.game_name == 'SpaceInvaders-v0':
                game_model_name = 'SpaceInvaders-v0.tfmodel'
            else:
                raise ValueError('Unknown game name {0}'.format(self.game_name))

            self.env = self.get_player_atari(train=False)
            num_actions = self.env.action_space.n
            self.nn = OfflinePredictor(PredictConfig(
                # model=Model(),
                model=Model(num_actions=num_actions, image_size=(84, 84)),
                session_init=SmartInit(self.ckpt_dir+game_model_name),
                input_names=['state'],
                output_names=['policy', 'pred_value']))

            self.config.DRL.Learn.actions = num_actions
            # for more about A3C training, please refer to https://github.com/tensorpack/

        # torch.save(self.nn.state_dict(), "{}/flappy_bird_state".format(self.config.DRL.Learn.ckpt_dir))
        # self.trainTransform = tv.transforms.Compose([tv.transforms.Resize(size=(80, 80)),
        #                                              tv.transforms.Grayscale(num_output_channels=1),
        #                                              tv.transforms.ToTensor(),
        #                                              # tv.transforms.Normalize()
        #                                              ])

    def append_sample(self, s_t, epsilon, time_step, discount_factor=1):

        with torch.no_grad():
            readout_t0 = self.nn(s_t.unsqueeze(0))
        readout_t0 = readout_t0.cpu().numpy()
        a_t = np.zeros([self.config.DRL.Learn.actions])
        if time_step % self.config.DRL.Learn.frame_per_action != 0 and time_step < self.config.DRL.Learn.explore:
            action_index = 0
        else:
            if random.random() <= epsilon:
                print("----------Random Action----------")
                action_index = random.randrange(self.config.DRL.Learn.actions)
            else:
                # values, indices = torch.max(readout_t, 1)
                action_index = np.argmax(readout_t0)
            # print('action is {0}'.format(str(a_t)))
        a_t[action_index] = 1
        # scale down epsilon
        if epsilon > self.config.DRL.Learn.final_epsilon and time_step > self.config.DRL.Learn.observe:
            epsilon -= (self.config.DRL.Learn.initial_epsilon -
                        self.config.DRL.Learn.final_epsilon) / self.config.DRL.Learn.explore
        x_t1_colored, r_t, terminal = self.game_state.next_frame(action_index)
        x_t1 = handle_image_input(x_t1_colored[:self.game_state.screen_width, :int(self.game_state.base_y)])
        # save_image_path=self.image_save_path, iter=time_step).to(self.device)
        # s_t1 = torch.stack(tensors=[x_t1], dim=0).to(self.device)
        s_t1 = torch.cat((s_t[1:, :, :], x_t1.to(self.device)))
        with torch.no_grad():
            readout_t1 = self.nn(s_t1.unsqueeze(0)).cpu().numpy()

        if self.apply_prioritize_memory:
            old_val = readout_t0[0][action_index]
            if terminal:
                readout_t0_update = r_t
            else:
                readout_t0_update = r_t + discount_factor * max(readout_t1[0])

            error = abs(old_val - readout_t0_update)
            self.memory.add(error, (s_t, torch.tensor(a_t, dtype=torch.float32),
                                    torch.tensor([r_t], dtype=torch.float32), s_t1, terminal))
        else:
            # store the transition in D
            self.memory.append((s_t, torch.tensor(a_t, dtype=torch.float32),
                                torch.tensor([r_t], dtype=torch.float32), s_t1, terminal))
            if len(self.memory) > self.config.DRL.Learn.replay_memory_size:
                self.memory.popleft()

        return readout_t0, s_t1, action_index, r_t, epsilon

    def save_checkpoint(self, time_step):
        model_states = {'FlappyBirdDQN': self.nn.state_dict()}
        optim_states = {'optim_DQN': self.optim.state_dict()}
        states = {'iter': time_step,
                  'model_states': model_states,
                  'optim_states': optim_states}

        filepath = os.path.join(self.ckpt_dir, "flappy_bird_model")
        with open(filepath, 'wb+') as f:
            torch.save(states, f)

    def load_checkpoint(self, model_name):
        filepath = os.path.join(self.ckpt_dir, model_name)
        if os.path.isfile(filepath):
            if self.device == 'cuda':
                with open(filepath, 'rb') as f:
                    checkpoint = torch.load(f)
            else:
                with open(filepath, 'rb') as f:
                    checkpoint = torch.load(f, map_location=torch.device('cpu'))
            self.global_iter = checkpoint['iter']
            self.nn.load_state_dict(checkpoint['model_states']['FlappyBirdDQN'])
            self.optim.load_state_dict(checkpoint['optim_states']['optim_DQN'])

    def sample_batch(self):
        if self.apply_prioritize_memory:
            minibatch, idxs, is_weights = self.memory.sample(self.config.DRL.Learn.batch)
        else:
            minibatch = random.sample(self.memory, self.config.DRL.Learn.batch)
            is_weights = np.ones(shape=[self.config.DRL.Learn.batch])
            idxs = None

        return minibatch, idxs, is_weights

    def get_player_atari(self, train=False, dumpdir=None):
        env = gym.make(self.game_name)
        if dumpdir:
            env = gym.wrappers.Monitor(env, dumpdir, video_callable=lambda _: True)
        env = FireResetEnv(env)
        env = MapState(env, lambda im: cv2.resize(im, (84, 84)))
        env = FrameStack(env, 4)
        if train:
            env = LimitLength(env, 60000)
        return env

    def test_model_and_generate_data(self, test_size=50000):
        if self.game_name == "flappybird":
            with open(self.data_save_path + 'action_values.txt', 'w') as action_values_file:
                action_index = 0
                x_t0_colored, r_t, terminal = self.game_state.next_frame(action_index)
                x_t0 = handle_image_input(x_t0_colored[:self.game_state.screen_width, :int(self.game_state.base_y)])
                s_t0 = torch.cat(tuple(x_t0 for _ in range(4))).to(self.device)
                # self.global_iter += 1
                while self.global_iter < test_size:
                    with torch.no_grad():
                        readout = self.nn(s_t0.unsqueeze(0))
                    readout = readout.cpu().numpy()
                    action_index = np.argmax(readout)
                    x_t1_colored, r_t, terminal = self.game_state.next_frame(action_index)
                    # store_state_action_data(img_colored=x_t0_colored[:self.game_state.screen_width, :int(self.game_state.base_y)],
                    #                         action_values=readout[0], reward=r_t, action_index=action_index,
                    #                         save_image_path=self.data_save_path, action_values_file=action_values_file,
                    #                         game_name=self.game_name, iteration_number=self.global_iter)
                    print("finishing save data iter {0}".format(self.global_iter))
                    x_t1 = handle_image_input(x_t1_colored[:self.game_state.screen_width, :int(self.game_state.base_y)])
                    s_t1 = torch.cat((s_t0[1:, :, :], x_t1.to(self.device)))
                    s_t0 = s_t1
                    x_t0_colored = x_t1_colored
                    self.global_iter += 1

        elif self.game_name == "Assault-v0" or self.game_name == "Breakout-v0" or self.game_name == "SpaceInvaders-v0":
            next_game_flag = True
            with open(self.data_save_path + 'action_values.txt', 'w') as action_values_file:
                while next_game_flag:
                    def predict(s):
                        """
                        Map from observation to action, with 0.01 greedy.
                        """
                        s = np.expand_dims(s, 0)  # batch
                        act = self.nn(s)[0][0].argmax()
                        value = self.nn(s)[1]
                        if random.random() < 0.01:
                            spc = self.env.action_space
                            act = spc.sample()
                        return act, value

                    s_t0 = self.env.reset()
                    sum_r = 0
                    while True:
                        act, value = predict(s_t0)
                        s_t1, r_t, isOver, info = self.env.step(act)
                        # if render:
                        #     self.env.render()
                        store_state_action_data(img_colored=s_t0[:, :, :, -2],
                                                action_values=value, reward=r_t, action_index=act,
                                                save_image_path=self.data_save_path, action_values_file=action_values_file,
                                                game_name=self.game_name, iteration_number=self.global_iter)

                        sum_r += r_t
                        s_t0 = s_t1
                        self.global_iter += 1
                        print("finishing save data iter {0}".format(self.global_iter))

                        if self.global_iter >= test_size:
                            next_game_flag = False
                            break

                        if isOver:
                            print ("Game is over with reward {0}".format(sum_r))
                            break

            # play_n_episodes(self.get_player_atari(train=False), self.nn, test_size, render=True)

    def train_DRl_model(self):
        # get the first state by doing nothing and preprocess the image to 80x80x4
        x_t0_colored, r_0, terminal = self.game_state.next_frame(0)
        x_t = handle_image_input(x_t0_colored[:self.game_state.screen_width, :int(self.game_state.base_y)])
        # s_t = torch.stack(tensors=[x_t], dim=0)
        s_t = torch.cat(tuple(x_t for _ in range(4))).to(self.device)

        # start training
        epsilon = self.config.DRL.Learn.initial_epsilon
        while "flappy bird" != "angry bird":
            # choose an action epsilon greedily
            readout_t0, s_t1, action_index, r_t, epsilon = self.append_sample(s_t, epsilon, self.global_iter)

            # only train if done observing
            if self.global_iter > self.config.DRL.Learn.observe:
                # sample a minibatch to train on
                minibatch, idxs, is_weights = self.sample_batch()

                # get the batch variables
                s_t_batch = torch.stack([d[0] for d in minibatch]).to(self.device)
                a_batch = torch.stack([d[1] for d in minibatch]).to(self.device)
                r_batch = torch.stack([d[2] for d in minibatch]).to(self.device)
                s_t1_batch = torch.stack([d[3] for d in minibatch]).to(self.device)

                y_batch = []
                # readout_j1_batch = readout.eval(feed_dict={s: s_j1_batch})
                readout_t1_batch = self.nn(s_t1_batch)
                readout_t0_batch = self.nn(s_t_batch)
                for i in range(0, len(minibatch)):
                    terminal = minibatch[i][4]
                    # if terminal, only equals reward
                    if terminal:
                        y_batch.append(r_batch[i])
                    else:
                        max_readout_t1_batch = torch.max(readout_t1_batch[i], dim=0)[0]
                        y_batch.append(r_batch[i] + self.config.DRL.Learn.gamma * max_readout_t1_batch)
                readout_action = torch.sum(torch.mul(readout_t0_batch, a_batch), dim=1)

                y_batch = torch.stack(y_batch).squeeze()
                errors = torch.abs(readout_action - y_batch).data.cpu().numpy()
                # update priority
                if self.apply_prioritize_memory:
                    for i in range(self.config.DRL.Learn.batch):
                        idx = idxs[i]
                        self.memory.update(idx, errors[i])

                # per-sample loss (reduction='none') so the importance-sampling weights
                # from prioritized replay are applied element-wise before averaging
                DRL_loss = (torch.FloatTensor(is_weights).to(self.device) *
                            tnf.mse_loss(readout_action, y_batch, reduction='none')).mean()

                # DRL_loss = square_loss(x=readout_action, y=y_batch)
                self.optim.zero_grad()
                DRL_loss.backward(retain_graph=True)
                self.optim.step()

            # update the old values
            s_t = s_t1
            self.global_iter += 1

            # save progress every 10000 iterations
            # if self.global_iter % self.ckpt_save_iter == 0:
            #     print('Saving VAE models')
            #     self.save_checkpoint('DRL-' + str(self.global_iter), verbose=True)

            # print info
            state = ""
            if self.global_iter <= self.config.DRL.Learn.observe:
                state = "observe"
            elif self.config.DRL.Learn.observe < self.global_iter <= self.config.DRL.Learn.observe + self.config.DRL.Learn.explore:
                state = "explore"
            else:
                state = "train"

            print("TIMESTEP", self.global_iter, "/ STATE", state, "/ EPSILON", epsilon, "/ ACTION", action_index,
                  "/ REWARD",
                  r_t, "/ Q_MAX %e" % np.max(readout_t0))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('image_path')
    parser.add_argument('--model_path', default='log/checkpoint')
    parser.add_argument('--output_path', default='figures/')
    parser.add_argument('--size', type=int, default=32)
    args = parser.parse_args()

    np.random.seed(0)
    # initialize the model
    predict_func = OfflinePredictor(
        PredictConfig(inputs_desc=[
            InputDesc(tf.float32, [None, INPUT_SIZE, INPUT_SIZE, 2],
                      'input_image')
        ],
                      tower_func=model.feedforward,
                      session_init=SaverRestore(args.model_path),
                      input_names=['input_image'],
                      output_names=['prob']))

    # simulate suda's gridworld input
    image = cv2.imread(
        args.image_path,
        cv2.IMREAD_GRAYSCALE)  # 0 if obstacle, 255 if free space
    h, w = image.shape[:2]
    obj = img2obj(image)  # list containing row major indices of objects

    # specify position is recent memory
    radius = 6
    #s = [340/2, 110/2]  # needs to be a list