def test_only(args):
    from imageio import imsave
    data_folder = args.get("data_folder")
    test_ckpt = args.get("test_ckpt")
    test_folder = args.get("test_folder")
    if not os.path.exists(test_folder):
        os.makedirs(test_folder)
    image_size = 224
    pred_config = PredictConfig(
        model=ProgressiveSynTex(args),
        session_init=SmartInit(test_ckpt),
        input_names=["pre_image_input", "image_target"],
        output_names=['stages-target/viz', 'loss_output']
    )
    predictor = OfflinePredictor(pred_config)
    test_ds = get_data(data_folder, image_size, isTrain=False)
    test_ds.reset_state()
    idx = 0
    losses = list()
    print("------------------ predict --------------")
    for pii, it in test_ds:
        output_array, loss_output = predictor(pii, it)
        if output_array.ndim == 4:
            for i in range(output_array.shape[0]):
                imsave(os.path.join(test_folder, "test-{}.jpg".format(idx)), output_array[i])
                idx += 1
        else:
            imsave(os.path.join(test_folder, "test-{}.jpg".format(idx)), output_array)
            idx += 1
        losses.append(loss_output)
        print("loss #", idx, "=", loss_output)
    print("Test and save", idx, "images to", test_folder, "avg loss =", np.mean(losses))
def init(args=None, is_running=0, pt=None):
    global ckpt2
    # network
    model = Net2()
    if is_running == 1:
        if pt == "":
            ckpt2 = tf.train.latest_checkpoint(logdir2)
        else:
            ckpt2 = '{}/{}'.format(logdir2, pt)
    else:
        ckpt2 = '{}/{}'.format(logdir2, args.ckpt) if args.ckpt else tf.train.latest_checkpoint(logdir2)
    session_inits = []
    if ckpt2:
        session_inits.append(SaverRestore(ckpt2))
    pred_conf = PredictConfig(
        model=model,
        input_names=['x_ppgs', 'x_mfccs', 'y_spec', 'y_mel'],
        output_names=['pred_spec', "ppgs"],
        session_init=ChainInit(session_inits))
    global predictor
    predictor = OfflinePredictor(pred_conf)
    if is_running == 1:
        return jsonify({"code": 0, "ckpt": ckpt2})
def thread_function(self):
    """ Run on secondary thread """
    pred = OfflinePredictor(
        PredictConfig(model=Model(IMAGE_SIZE, FRAME_HISTORY, self.METHOD, self.NUM_ACTIONS, GAMMA, ""),
                      session_init=get_model_loader(self.fname_model.name),
                      input_names=['state'],
                      output_names=['Qvalue']))
    # demo pretrained model one episode at a time
    if self.task_value == 'Play':
        play_n_episodes(get_player(files_list=self.selected_list, viz=0.01,
                                   data_type=self.window.usecase,
                                   saveGif=self.GIF_value,
                                   saveVideo=self.video_value,
                                   task='play'),
                        pred, self.num_files, viewer=self.window)
    # run episodes in parallel and evaluate pretrained model
    elif self.task_value == 'Evaluation':
        play_n_episodes(get_player(files_list=self.selected_list, viz=0.01,
                                   data_type=self.window.usecase,
                                   saveGif=self.GIF_value,
                                   saveVideo=self.video_value,
                                   task='eval'),
                        pred, self.num_files, viewer=self.window)
def _make_pred_func(self, load):
    from train import ResNetFPNTrackModel
    pred_model = ResNetFPNTrackModel()
    predcfg = PredictConfig(
        model=pred_model,
        session_init=get_model_loader(load),
        input_names=pred_model.get_inference_tensor_names()[0],
        output_names=pred_model.get_inference_tensor_names()[1])
    return OfflinePredictor(predcfg)
def load_model(self):
    print('Loading Model...')
    model_path = self.model_path
    model_constructor = self.get_model()
    pred_config = PredictConfig(
        model=model_constructor(self.nr_types, self.input_shape, self.mask_shape, self.input_norm),
        session_init=get_model_loader(model_path),
        input_names=self.input_tensor_names,
        output_names=self.output_tensor_names)
    self.predictor = OfflinePredictor(pred_config)
def __init__(self, name, need_network=True, need_img=True, model="best"):
    super().__init__(name=name, is_deterministic=True)
    self._resizer = CustomResize(cfg.PREPROC.TEST_SHORT_EDGE_SIZE, cfg.PREPROC.MAX_SIZE)
    self._prev_box = None
    self._ff_gt_feats = None
    self._need_network = need_network
    self._need_img = need_img
    self._rotated_bbox = None
    if need_network:
        logger.set_logger_dir("/tmp/test_log_/" + str(random.randint(0, 10000)), 'd')
        if model == "best":
            load = "train_log/hard_mining3/model-1360500"
        elif model == "nohardexamples":
            load = "train_log/condrcnn_all_2gpu_lrreduce2/model-1200500"
        elif model == "newrpn":
            load = "train_log/newrpn1/model"
        elif model == "resnet50_nohardexamples":
            load = "train_log/condrcnn_all_resnet50/model-1200500"
            cfg.BACKBONE.RESNET_NUM_BLOCKS = [3, 4, 6, 3]
        elif model == "resnet50":
            load = "train_log/hard_mining3_resnet50/model-1360500"
            cfg.BACKBONE.RESNET_NUM_BLOCKS = [3, 4, 6, 3]
        elif model == "gotonly":
            load = "train_log/hard_mining3_onlygot/model-1361000"
        elif model.startswith("checkpoint:"):
            load = model.replace("checkpoint:", "")
        else:
            assert False, ("unknown model", model)
        from dataset import DetectionDataset
        # init tensorpack model
        # cfg.freeze(False)
        DetectionDataset()  # initialize the config with information from our dataset
        cfg.EXTRACT_GT_FEATURES = True
        cfg.MODE_TRACK = False
        extract_model = ResNetFPNModel()
        extract_ff_feats_cfg = PredictConfig(
            model=extract_model,
            session_init=get_model_loader(load),
            input_names=['image', 'roi_boxes'],
            output_names=['rpn/feature'])
        finalize_configs(is_training=False)
        self._extract_func = OfflinePredictor(extract_ff_feats_cfg)
        cfg.EXTRACT_GT_FEATURES = False
        cfg.MODE_TRACK = True
        cfg.USE_PRECOMPUTED_REF_FEATURES = True
        self._pred_func = self._make_pred_func(load)
def __init__(self):
    super().__init__(name='ArgmaxTracker', is_deterministic=True)
    self._ref_img = None
    self._ref_bbox = None
    self._prev_box = None
    model = self._init_model()
    load = "train_log/condrcnn_onlygot/model-460000"
    predcfg = PredictConfig(
        model=model,
        session_init=get_model_loader(load),
        input_names=model.get_inference_tensor_names()[0],
        output_names=model.get_inference_tensor_names()[1])
    self._pred_func = OfflinePredictor(predcfg)
def _test():
    import numpy as np
    # from tensorpack.tfutils import TowerContext
    from tensorpack import PredictConfig, OfflinePredictor

    pretrained = False

    models = [
        resnet10,
        resnet12,
        resnet14,
        resnet16,
        resnet18_wd4,
        resnet18_wd2,
        resnet18_w3d4,
        resnet18,
        resnet34,
        resnet50,
        resnet50b,
        resnet101,
        resnet101b,
        resnet152,
        resnet152b,
        resnet200,
        resnet200b,
        seresnet18,
        seresnet34,
        seresnet50,
        seresnet50b,
        seresnet101,
        seresnet101b,
        seresnet152,
        seresnet152b,
        seresnet200,
        seresnet200b,
    ]

    for model in models:
        net = model(pretrained=pretrained)
        pred_config = PredictConfig(
            session_init=None,
            model=net,
            input_names=['input'],
            output_names=['label'])
        pred = OfflinePredictor(pred_config)
        img = np.zeros((224, 224, 3), np.uint8)
        prediction = pred([img])[0]
        print(prediction)
    pass
def test(args):
    from imageio import imsave
    from tictoc import Timer
    data_folder = args.get("data_folder")
    image_size = args.get("image_size")
    batch_size = args.get("batch_size") or BATCH
    test_ckpt = args.get("test_ckpt")
    test_folder = args.get("test_folder")
    if not os.path.exists(test_folder):
        os.makedirs(test_folder)
    pred_config = PredictConfig(
        model=Style2PO(args),
        session_init=SmartInit(test_ckpt),
        input_names=["image_input", "image_target"],
        output_names=['stages-target/viz', 'loss_output'])
    predictor = OfflinePredictor(pred_config)
    zmin, zmax = (0, 1) if args.get("act_input") == "identity" else (-1, 1)
    test_ds = get_data(data_folder, image_size, False, zmin, zmax, batch_size)
    test_ds.reset_state()
    idx = 0
    losses = list()
    print("------------------ predict --------------")
    timer = Timer("predict", tic=True, show=Timer.STDOUT)
    for rz, it in test_ds:
        output_array, loss_output = predictor(rz, it)
        if output_array.ndim == 4:
            for i in range(output_array.shape[0]):
                imsave(os.path.join(test_folder, "test-{}.jpg".format(idx)), output_array[i])
                idx += 1
        else:
            imsave(os.path.join(test_folder, "test-{}.jpg".format(idx)), output_array)
            idx += 1
        losses.append(loss_output)
        print("loss #", idx, "=", loss_output)
    timer.toc(Timer.STDOUT)
    print("Test and save", idx, "images to", test_folder, "avg loss =", np.mean(losses))
def export_eval_protobuf_model(checkpoint_dir, model_name, dataset, quant_type, output_file, batch_size):
    _, test_data, (img_shape, label_shape) = datasets.DATASETS[dataset]()
    model_func, input_spec, output_spec = get_model_func(
        "eval", model_name, quant_type, img_shape, label_shape[0])
    input_names = [i.name for i in input_spec]
    output_names = [o.name for o in output_spec]
    predictor_config = PredictConfig(session_init=SaverRestore(checkpoint_dir + "/checkpoint"),
                                     tower_func=model_func,
                                     input_signature=input_spec,
                                     input_names=input_names,
                                     output_names=output_names,
                                     create_graph=False)

    print("Exporting optimised protobuf graph...")
    K.set_learning_phase(False)
    ModelExporter(predictor_config).export_compact(output_file, optimize=False)
    K.clear_session()

    pred = OfflinePredictor(predictor_config)
    test_data = BatchData(test_data, batch_size, remainder=True)
    test_data.reset_state()
    num_correct = 0
    num_processed = 0
    for img, label in tqdm(test_data):
        num_correct += sum(pred(img)[0].argmax(axis=1) == label.argmax(axis=1))
        num_processed += img.shape[0]

    print("Exported model has accuracy {:.4f}".format(num_correct / num_processed))
    return input_names, output_names, {i.name: i.shape for i in input_spec}
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
# ROM_FILE = args.rom
METHOD = args.algo
# set num_actions
init_player = MedicalPlayer(directory=data_dir,
                            files_list=test_list,
                            screen_dims=IMAGE_SIZE,
                            spacing=SPACING)
NUM_ACTIONS = init_player.action_space.n
num_validation_files = init_player.files.num_files

if args.task != 'train':
    assert args.load is not None
    pred = OfflinePredictor(PredictConfig(
        model=Model(),
        session_init=get_model_loader(args.load),
        input_names=['state'],
        output_names=['Qvalue']))
    if args.task == 'play':
        t0 = time.time()
        play_n_episodes(get_player(directory=data_dir,
                                   files_list=test_list,
                                   viz=0.01,
                                   saveGif=args.saveGif,
                                   saveVideo=args.saveVideo),
                        pred, num_validation_files)
        t1 = time.time()
        print(t1 - t0)
    elif args.task == 'eval':
        eval_model_multithread(pred, EVAL_EPISODE, get_player)
    else:
##########################################################
# initialize states and Qvalues for the various agents
state_names = []
qvalue_names = []
for i in range(0, args.agents):
    state_names.append('state_{}'.format(i))
    qvalue_names.append('Qvalue_{}'.format(i))
############################################################

if args.task != 'train':
    assert args.load is not None
    pred = OfflinePredictor(
        PredictConfig(model=Model(agents=args.agents),
                      session_init=get_model_loader(args.load),
                      input_names=state_names,
                      output_names=qvalue_names))
    # demo pretrained model one episode at a time
    if args.task == 'play':
        play_n_episodes(get_player(files_list=args.files,
                                   viz=0.01,
                                   saveGif=args.saveGif,
                                   saveVideo=args.saveVideo,
                                   task='play',
                                   agents=args.agents),
                        pred, num_files,
                        agents=args.agents,
                        reward_strategy=args.reward_strategy)
    # run episodes in parallel and evaluate pretrained model
METHOD = args.algo
# load files into env to set num_actions, num_validation_files
init_player = MedicalPlayer(files_list=args.files,
                            # files_list=files_list,
                            data_type=args.type,
                            screen_dims=IMAGE_SIZE,
                            task='play')
NUM_ACTIONS = init_player.action_space.n
num_files = init_player.files.num_files

if args.task != 'train':
    assert args.load is not None
    pred = OfflinePredictor(
        PredictConfig(model=Model(IMAGE_SIZE, FRAME_HISTORY, METHOD, NUM_ACTIONS, GAMMA, args.trainable),
                      session_init=get_model_loader(args.load),
                      input_names=['state'],
                      output_names=['Qvalue']))
    # demo pretrained model one episode at a time
    if args.task == 'play':
        play_n_episodes(get_player(files_list=args.files,
                                   data_type=args.type,
                                   viz=0,
                                   saveGif=args.saveGif,
                                   saveVideo=args.saveVideo,
                                   task='play'),
                        pred, num_files, viewer=None)
    # run episodes in parallel and evaluate pretrained model
class DRLDataGenerator():
    def __init__(self, game_name, config, global_model_data_path, local_test_flag):
        if not local_test_flag:
            mkdirs(global_model_data_path + config.DRL.Learn.data_save_path)
        self.game_name = game_name
        self.data_save_path = global_model_data_path + config.DRL.Learn.data_save_path
        self.config = config
        self.global_iter = 0
        self.ckpt_dir = global_model_data_path + self.config.DRL.Learn.ckpt_dir
        self.ckpt_save_iter = self.config.DRL.Learn.ckpt_save_iter
        if not local_test_flag:
            mkdirs(self.ckpt_dir)
        self.apply_prioritize_memory = False
        if self.apply_prioritize_memory:
            self.memory = PrioritizedReplay(capacity=self.config.DRL.Learn.replay_memory_size)
        else:
            # store the previous observations in replay memory
            self.memory = deque()
        if self.game_name == 'flappybird':
            # FlappyBird applies pytorch
            use_cuda = config.DRL.Learn.cuda and torch.cuda.is_available()
            self.device = 'cuda' if use_cuda else 'cpu'
            self.actions_number = self.config.DRL.Learn.actions
            # self.nn = DeepQNetwork().to(self.device)
            self.nn = FlappyBirdDQN().to(self.device)
            self.optim = optim.Adam(self.nn.parameters(), lr=self.config.DRL.Learn.learning_rate)
            if config.DRL.Learn.ckpt_load:
                self.load_checkpoint(model_name='flappy_bird_model')
            self.game_state = FlappyBird()
            if torch.cuda.is_available():
                torch.cuda.manual_seed(123)
            else:
                torch.manual_seed(123)
        elif self.game_name == 'Assault-v0' or self.game_name == 'Breakout-v0' or self.game_name == 'SpaceInvaders-v0':
            if self.game_name == 'Assault-v0':
                game_model_name = 'Assault-v0.tfmodel'
            elif self.game_name == 'Breakout-v0':
                game_model_name = 'Breakout-v0.npz'
            elif self.game_name == 'SpaceInvaders-v0':
                game_model_name = 'SpaceInvaders-v0.tfmodel'
            else:
                raise ValueError('Unknown game name {0}'.format(self.game_name))
            self.env = self.get_player_atari(train=False)
            num_actions = self.env.action_space.n
            self.nn = OfflinePredictor(PredictConfig(
                # model=Model(),
                model=Model(num_actions=num_actions, image_size=(84, 84)),
                session_init=SmartInit(self.ckpt_dir + game_model_name),
                input_names=['state'],
                output_names=['policy', 'pred_value']))
            self.config.DRL.Learn.actions = num_actions
            # for more about A3C training, please refer to https://github.com/tensorpack/
        # torch.save(self.nn.state_dict(), "{}/flappy_bird_state".format(self.config.DRL.Learn.ckpt_dir))
        # self.trainTransform = tv.transforms.Compose([tv.transforms.Resize(size=(80, 80)),
        #                                              tv.transforms.Grayscale(num_output_channels=1),
        #                                              tv.transforms.ToTensor(),
        #                                              # tv.transforms.Normalize()
        #                                              ])

    def append_sample(self, s_t, epsilon, time_step, discount_factor=1):
        with torch.no_grad():
            readout_t0 = self.nn(s_t.unsqueeze(0))
            readout_t0 = readout_t0.cpu().numpy()
        a_t = np.zeros([self.config.DRL.Learn.actions])
        if time_step % self.config.DRL.Learn.frame_per_action != 0 and time_step < self.config.DRL.Learn.explore:
            action_index = 0
        else:
            if random.random() <= epsilon:
                print("----------Random Action----------")
                action_index = random.randrange(self.config.DRL.Learn.actions)
            else:
                # values, indices = torch.max(readout_t, 1)
                action_index = np.argmax(readout_t0)
        # print('action is {0}'.format(str(a_t)))
        a_t[action_index] = 1
        # scale down epsilon
        if epsilon > self.config.DRL.Learn.final_epsilon and time_step > self.config.DRL.Learn.observe:
            epsilon -= (self.config.DRL.Learn.initial_epsilon -
                        self.config.DRL.Learn.final_epsilon) / self.config.DRL.Learn.explore
        x_t1_colored, r_t, terminal = self.game_state.next_frame(action_index)
        x_t1 = handle_image_input(x_t1_colored[:self.game_state.screen_width, :int(self.game_state.base_y)])
        # save_image_path=self.image_save_path, iter=time_step).to(self.device)
        # s_t1 = torch.stack(tensors=[x_t1], dim=0).to(self.device)
        s_t1 = torch.cat((s_t[1:, :, :], x_t1.to(self.device)))
        with torch.no_grad():
            readout_t1 = self.nn(s_t1.unsqueeze(0)).cpu().numpy()
        if self.apply_prioritize_memory:
            old_val = readout_t0[0][action_index]
            if terminal:
                readout_t0_update = r_t
            else:
                readout_t0_update = r_t + discount_factor * max(readout_t1[0])
            error = abs(old_val - readout_t0_update)
            self.memory.add(error, (s_t, torch.tensor(a_t, dtype=torch.float32),
                                    torch.tensor([r_t], dtype=torch.float32), s_t1, terminal))
        else:
            # store the transition in D
            self.memory.append((s_t, torch.tensor(a_t, dtype=torch.float32),
                                torch.tensor([r_t], dtype=torch.float32), s_t1, terminal))
            if len(self.memory) > self.config.DRL.Learn.replay_memory_size:
                self.memory.popleft()
        return readout_t0, s_t1, action_index, r_t, epsilon

    def save_checkpoint(self, time_step):
        model_states = {'FlappyBirdDQN': self.nn.state_dict()}
        optim_states = {'optim_DQN': self.optim.state_dict()}
        states = {'iter': time_step,
                  'model_states': model_states,
                  'optim_states': optim_states}
        filepath = os.path.join(self.ckpt_dir, "flappy_bird_model")
        with open(filepath, 'wb+') as f:
            torch.save(states, f)

    def load_checkpoint(self, model_name):
        filepath = os.path.join(self.ckpt_dir, model_name)
        if os.path.isfile(filepath):
            if self.device == 'cuda':
                with open(filepath, 'rb') as f:
                    checkpoint = torch.load(f)
            else:
                with open(filepath, 'rb') as f:
                    checkpoint = torch.load(f, map_location=torch.device('cpu'))
            self.global_iter = checkpoint['iter']
            self.nn.load_state_dict(checkpoint['model_states']['FlappyBirdDQN'])
            self.optim.load_state_dict(checkpoint['optim_states']['optim_DQN'])

    def sample_batch(self):
        if self.apply_prioritize_memory:
            minibatch, idxs, is_weights = self.memory.sample(self.config.DRL.Learn.batch)
        else:
            minibatch = random.sample(self.memory, self.config.DRL.Learn.batch)
            is_weights = np.ones(shape=[self.config.DRL.Learn.batch])
            idxs = None
        return minibatch, idxs, is_weights

    def get_player_atari(self, train=False, dumpdir=None):
        env = gym.make(self.game_name)
        if dumpdir:
            env = gym.wrappers.Monitor(env, dumpdir, video_callable=lambda _: True)
        env = FireResetEnv(env)
        env = MapState(env, lambda im: cv2.resize(im, (84, 84)))
        env = FrameStack(env, 4)
        if train:
            env = LimitLength(env, 60000)
        return env

    def test_model_and_generate_data(self, test_size=50000):
        if self.game_name == "flappybird":
            with open(self.data_save_path + 'action_values.txt', 'w') as action_values_file:
                action_index = 0
                x_t0_colored, r_t, terminal = self.game_state.next_frame(action_index)
                x_t0 = handle_image_input(x_t0_colored[:self.game_state.screen_width, :int(self.game_state.base_y)])
                s_t0 = torch.cat(tuple(x_t0 for _ in range(4))).to(self.device)
                # self.global_iter += 1
                while self.global_iter < test_size:
                    with torch.no_grad():
                        readout = self.nn(s_t0.unsqueeze(0))
                        readout = readout.cpu().numpy()
                    action_index = np.argmax(readout)
                    x_t1_colored, r_t, terminal = self.game_state.next_frame(action_index)
                    # store_state_action_data(img_colored=x_t0_colored[:self.game_state.screen_width, :int(self.game_state.base_y)],
                    #                         action_values=readout[0], reward=r_t, action_index=action_index,
                    #                         save_image_path=self.data_save_path, action_values_file=action_values_file,
                    #                         game_name=self.game_name, iteration_number=self.global_iter)
                    print("finishing save data iter {0}".format(self.global_iter))
                    x_t1 = handle_image_input(x_t1_colored[:self.game_state.screen_width, :int(self.game_state.base_y)])
                    s_t1 = torch.cat((s_t0[1:, :, :], x_t1.to(self.device)))
                    s_t0 = s_t1
                    x_t0_colored = x_t1_colored
                    self.global_iter += 1
        elif self.game_name == "Assault-v0" or self.game_name == "Breakout-v0" or self.game_name == "SpaceInvaders-v0":
            next_game_flag = True
            with open(self.data_save_path + 'action_values.txt', 'w') as action_values_file:
                while next_game_flag:
                    def predict(s):
                        """
                        Map from observation to action, with 0.01 greedy.
                        """
                        s = np.expand_dims(s, 0)  # batch
                        act = self.nn(s)[0][0].argmax()
                        value = self.nn(s)[1]
                        if random.random() < 0.01:
                            spc = self.env.action_space
                            act = spc.sample()
                        return act, value

                    s_t0 = self.env.reset()
                    sum_r = 0
                    while True:
                        act, value = predict(s_t0)
                        s_t1, r_t, isOver, info = self.env.step(act)
                        # if render:
                        #     self.env.render()
                        store_state_action_data(img_colored=s_t0[:, :, :, -2], action_values=value, reward=r_t,
                                                action_index=act, save_image_path=self.data_save_path,
                                                action_values_file=action_values_file,
                                                game_name=self.game_name, iteration_number=self.global_iter)
                        sum_r += r_t
                        s_t0 = s_t1
                        self.global_iter += 1
                        print("finishing save data iter {0}".format(self.global_iter))
                        if self.global_iter >= test_size:
                            next_game_flag = False
                            break
                        if isOver:
                            print("Game is over with reward {0}".format(sum_r))
                            break
        # play_n_episodes(self.get_player_atari(train=False), self.nn, test_size, render=True)

    def train_DRl_model(self):
        # get the first state by doing nothing and preprocess the image to 80x80x4
        x_t0_colored, r_0, terminal = self.game_state.next_frame(0)
        x_t = handle_image_input(x_t0_colored[:self.game_state.screen_width, :int(self.game_state.base_y)])
        # s_t = torch.stack(tensors=[x_t], dim=0)
        s_t = torch.cat(tuple(x_t for _ in range(4))).to(self.device)
        # start training
        epsilon = self.config.DRL.Learn.initial_epsilon
        while "flappy bird" != "angry bird":
            # choose an action epsilon greedily
            readout_t0, s_t1, action_index, r_t, epsilon = self.append_sample(s_t, epsilon, self.global_iter)
            # only train if done observing
            if self.global_iter > self.config.DRL.Learn.observe:
                # sample a minibatch to train on
                minibatch, idxs, is_weights = self.sample_batch()
                # get the batch variables
                s_t_batch = torch.stack([d[0] for d in minibatch]).to(self.device)
                a_batch = torch.stack([d[1] for d in minibatch]).to(self.device)
                r_batch = torch.stack([d[2] for d in minibatch]).to(self.device)
                s_t1_batch = torch.stack([d[3] for d in minibatch]).to(self.device)
                y_batch = []
                # readout_j1_batch = readout.eval(feed_dict={s: s_j1_batch})
                readout_t1_batch = self.nn(s_t1_batch)
                readout_t0_batch = self.nn(s_t_batch)
                for i in range(0, len(minibatch)):
                    terminal = minibatch[i][4]
                    # if terminal, only equals reward
                    if terminal:
                        y_batch.append(r_batch[i])
                    else:
                        max_readout_t1_batch = torch.max(readout_t1_batch[i], dim=0)[0]
                        y_batch.append(r_batch[i] + self.config.DRL.Learn.gamma * max_readout_t1_batch)
                readout_action = torch.sum(torch.mul(readout_t0_batch, a_batch), dim=1)
                y_batch = torch.stack(y_batch).squeeze()
                errors = torch.abs(readout_action - y_batch).data.cpu().numpy()
                # update priority
                if self.apply_prioritize_memory:
                    for i in range(self.config.DRL.Learn.batch):
                        idx = idxs[i]
                        self.memory.update(idx, errors[i])
                DRL_loss = (torch.FloatTensor(is_weights).to(self.device) *
                            tnf.mse_loss(readout_action, y_batch)).mean()
                # DRL_loss = square_loss(x=readout_action, y=y_batch)
                self.optim.zero_grad()
                DRL_loss.backward(retain_graph=True)
                self.optim.step()
            # update the old values
            s_t = s_t1
            self.global_iter += 1
            # save progress every 10000 iterations
            # if self.global_iter % self.ckpt_save_iter == 0:
            #     print('Saving VAE models')
            #     self.save_checkpoint('DRL-' + str(self.global_iter), verbose=True)
            # print info
            state = ""
            if self.global_iter <= self.config.DRL.Learn.observe:
                state = "observe"
            elif self.config.DRL.Learn.observe < self.global_iter <= self.config.DRL.Learn.observe + self.config.DRL.Learn.explore:
                state = "explore"
            else:
                state = "train"
            print("TIMESTEP", self.global_iter, "/ STATE", state,
                  "/ EPSILON", epsilon, "/ ACTION", action_index, "/ REWARD", r_t,
                  "/ Q_MAX %e" % np.max(readout_t0))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('image_path')
    parser.add_argument('--model_path', default='log/checkpoint')
    parser.add_argument('--output_path', default='figures/')
    parser.add_argument('--size', type=int, default=32)
    args = parser.parse_args()

    np.random.seed(0)

    # initialize the model
    predict_func = OfflinePredictor(
        PredictConfig(inputs_desc=[
            InputDesc(tf.float32, [None, INPUT_SIZE, INPUT_SIZE, 2], 'input_image')
        ],
            tower_func=model.feedforward,
            session_init=SaverRestore(args.model_path),
            input_names=['input_image'],
            output_names=['prob']))

    # simulate suda's gridworld input
    image = cv2.imread(args.image_path, cv2.IMREAD_GRAYSCALE)  # 0 if obstacle, 255 if free space
    h, w = image.shape[:2]
    obj = img2obj(image)  # list containing row major indices of objects

    # specify position is recent memory
    radius = 6
    # s = [340/2, 110/2]  # needs to be a list