Example #1
    def __init__(self, display_size):
        pygame.init()

        self.surface = pygame.display.set_mode(display_size, 0, 24)
        pygame.display.set_caption('UNREAL')

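        # Query the environment for the action/objective sizes and build the global
        # UNREAL network for display only (thread index -1, CPU device, zeroed loss weights).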
        self.action_size = Environment.get_action_size(flags.env_type,
                                                       flags.env_name)
        self.objective_size = Environment.get_objective_size(
            flags.env_type, flags.env_name)
        self.global_network = UnrealModel(self.action_size,
                                          self.objective_size,
                                          -1,
                                          flags.use_lstm,
                                          flags.use_pixel_change,
                                          flags.use_value_replay,
                                          flags.use_reward_prediction,
                                          0.0,
                                          0.0,
                                          "/cpu:0",
                                          for_display=True)
        self.environment = Environment.create_environment(
            flags.env_type,
            flags.env_name,
            env_args={
                'episode_schedule': flags.split,
                'log_action_trace': flags.log_action_trace,
                'max_states_per_scene': flags.episodes_per_scene,
                'episodes_per_scene_test': flags.episodes_per_scene
            })
        self.font = pygame.font.SysFont(None, 20)
        self.value_history = ValueHistory()
        self.state_history = StateHistory()
        self.episode_reward = 0
Example #2
  def __init__(self,
               thread_index,
               global_network,
               initial_learning_rate,
               learning_rate_input,
               grad_applier,
               env_type,
               env_name,
               use_lstm,
               use_pixel_change,
               use_value_replay,
               use_reward_prediction,
               pixel_change_lambda,
               entropy_beta,
               local_t_max,
               gamma,
               gamma_pc,
               experience_history_size,
               max_global_time_step,
               device):

    self.thread_index = thread_index
    self.learning_rate_input = learning_rate_input
    self.env_type = env_type
    self.env_name = env_name
    self.use_lstm = use_lstm
    self.use_pixel_change = use_pixel_change
    self.use_value_replay = use_value_replay
    self.use_reward_prediction = use_reward_prediction
    self.local_t_max = local_t_max
    self.gamma = gamma
    self.gamma_pc = gamma_pc
    self.experience_history_size = experience_history_size
    self.max_global_time_step = max_global_time_step
    self.action_size = Environment.get_action_size(env_type, env_name)
    self.objective_size = Environment.get_objective_size(env_type, env_name)
    
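    # Per-thread copy of the UNREAL model, built on the assigned device.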
    self.local_network = UnrealModel(self.action_size,
                                     self.objective_size,
                                     thread_index,
                                     use_lstm,
                                     use_pixel_change,
                                     use_value_replay,
                                     use_reward_prediction,
                                     pixel_change_lambda,
                                     entropy_beta,
                                     device)
    self.local_network.prepare_loss()

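    # Gradients of the local loss are applied to the shared global variables;
    # sync_from() builds an op that copies the global weights into this local network.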
    self.apply_gradients = grad_applier.minimize_local(self.local_network.total_loss,
                                                       global_network.get_vars(),
                                                       self.local_network.get_vars())
    
    self.sync = self.local_network.sync_from(global_network)
    self.experience = Experience(self.experience_history_size)
    self.local_t = 0
    self.initial_learning_rate = initial_learning_rate
    self.episode_reward = 0
    # For log output
    self.prev_local_t = 0
Example #3
 def __init__(self):
     self.action_size = Environment.get_action_size(flags.env_type,
                                                    flags.env_name)
     self.objective_size = Environment.get_objective_size(
         flags.env_type, flags.env_name)
     self.global_network = UnrealModel(self.action_size,
                                       self.objective_size,
                                       -1,
                                       flags.use_lstm,
                                       flags.use_pixel_change,
                                       flags.use_value_replay,
                                       flags.use_reward_prediction,
                                       0.0,
                                       0.0,
                                       "/cpu:0",
                                       for_display=True)
     self.environment = Environment.create_environment(
         flags.env_type,
         flags.env_name,
         env_args={
             'episode_schedule': flags.split,
             'log_action_trace': flags.log_action_trace,
             'seed': flags.seed,
             # 'max_states_per_scene': flags.episodes_per_scene,
             'episodes_per_scene_test': flags.episodes_per_scene
         })
     self.episode_reward = 0
     self.cnt_success = 0
Example #4
 def __init__(self):
     self.action_size = Environment.get_action_size(flags.env_type,
                                                    flags.env_name)
     self.objective_size = Environment.get_objective_size(
         flags.env_type, flags.env_name)
     print('flags:use_pixel_change {}'.format(flags.use_pixel_change))
     sleep(10)
     self.global_network = UnrealModel(self.action_size,
                                       self.objective_size,
                                       -1,
                                       flags.use_lstm,
                                       flags.use_pixel_change,
                                       flags.use_value_replay,
                                       flags.use_reward_prediction,
                                       0.0,
                                       0.0,
                                       "/cpu:0",
                                       for_display=True)
     self.environment = Environment.create_environment(
         flags.env_type,
         flags.env_name,
         env_args={
             'episode_schedule': flags.split,
             'log_action_trace': flags.log_action_trace,
             'max_states_per_scene': flags.episodes_per_scene,
             'episodes_per_scene_test': flags.episodes_per_scene
         })
     print('\n======\nENV in Evaluate::ctor')
     print(self.environment)
     print(self.global_network)
     print('val_replay!!! {}'.format(flags.use_value_replay))
     print(flags.split)
     print('=======\n')
     sleep(10)
     self.episode_reward = 0
Example #5
    def __init__(self, display_size):
        pygame.init()

        self.surface = pygame.display.set_mode(display_size, 0, 24)
        name = 'UNREAL' if flags.segnet == 0 else "A3C ErfNet"
        pygame.display.set_caption(name)

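        # Simulator config provides the frame size and the object-type (label) mapping file.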
        env_config = sim_config.get(flags.env_name)
        self.image_shape = [
            env_config.get('height', 88),
            env_config.get('width', 88)
        ]
        segnet_param_dict = {'segnet_mode': flags.segnet}
        is_training = tf.placeholder(tf.bool, name="training")
        map_file = env_config.get('objecttypes_file', '../../objectTypes.csv')
        self.label_mapping = pd.read_csv(map_file, sep=',', header=0)
        self.get_col_index()

        self.action_size = Environment.get_action_size(flags.env_type,
                                                       flags.env_name)
        self.objective_size = Environment.get_objective_size(
            flags.env_type, flags.env_name)
        self.global_network = UnrealModel(self.action_size,
                                          self.objective_size,
                                          -1,
                                          flags.use_lstm,
                                          flags.use_pixel_change,
                                          flags.use_value_replay,
                                          flags.use_reward_prediction,
                                          0.0,
                                          0.0,
                                          "/gpu:0",
                                          segnet_param_dict=segnet_param_dict,
                                          image_shape=self.image_shape,
                                          is_training=is_training,
                                          n_classes=flags.n_classes,
                                          segnet_lambda=flags.segnet_lambda,
                                          dropout=flags.dropout,
                                          for_display=True)
        self.environment = Environment.create_environment(
            flags.env_type,
            flags.env_name,
            flags.termination_time_sec,
            env_args={
                'episode_schedule': flags.split,
                'log_action_trace': flags.log_action_trace,
                'max_states_per_scene': flags.episodes_per_scene,
                'episodes_per_scene_test': flags.episodes_per_scene
            })
        self.font = pygame.font.SysFont(None, 20)
        self.value_history = ValueHistory()
        self.state_history = StateHistory()
        self.episode_reward = 0
Example #6
def main(args):
    action_size = Environment.get_action_size(flags.env_type, flags.env_name)
    objective_size = Environment.get_objective_size(flags.env_type,
                                                    flags.env_name)
    global_network = UnrealModel(action_size, objective_size, -1,
                                 flags.use_lstm, flags.use_pixel_change,
                                 flags.use_value_replay,
                                 flags.use_reward_prediction, 0.0, 0.0,
                                 "/cpu:0")  # use CPU for weight visualize tool

    sess = tf.Session()

    init = tf.global_variables_initializer()
    sess.run(init)

    saver = tf.train.Saver()
    checkpoint = tf.train.get_checkpoint_state(flags.checkpoint_dir)
    if checkpoint and checkpoint.model_checkpoint_path:
        saver.restore(sess, checkpoint.model_checkpoint_path)
        print("checkpoint loaded:", checkpoint.model_checkpoint_path)
    else:
        print("Could not find old checkpoint")

    vars = {}
    var_list = global_network.get_vars()
    for v in var_list:
        vars[v.name] = v

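    # First convolutional layer weights, indexed below as [height, width, input_channel, output_channel].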
    W_conv1 = sess.run(vars['net_-1/base_conv/W_base_conv1:0'])

    # show graph of W_conv1
    fig, axes = plt.subplots(3,
                             16,
                             figsize=(12, 6),
                             subplot_kw={
                                 'xticks': [],
                                 'yticks': []
                             })
    fig.subplots_adjust(hspace=0.1, wspace=0.1)

    for ax, i in zip(axes.flat, range(3 * 16)):
        inch = i // 16
        outch = i % 16
        img = W_conv1[:, :, inch, outch]
        ax.imshow(img, cmap=plt.cm.gray, interpolation='nearest')
        ax.set_title(str(inch) + "," + str(outch))

    plt.show()
Example #7
  def __init__(self):
    self.action_size = Environment.get_action_size(flags.env_type, flags.env_name)
    self.objective_size = Environment.get_objective_size(flags.env_type, flags.env_name)

    env_config = sim_config.get(flags.env_name)
    self.image_shape = [env_config['height'], env_config['width']]
    segnet_param_dict = {'segnet_mode': flags.segnet}
    is_training = tf.placeholder(tf.bool, name="training") # for display param in UnrealModel says its value

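    # Display/evaluation copy of the model: pixel change lambda and entropy beta are set to 0.0.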
    self.global_network = UnrealModel(self.action_size,
                                      self.objective_size,
                                      -1,
                                      flags.use_lstm,
                                      flags.use_pixel_change,
                                      flags.use_value_replay,
                                      flags.use_reward_prediction,
                                      0.0, #flags.pixel_change_lambda
                                      0.0, #flags.entropy_beta
                                      device,
                                      segnet_param_dict=segnet_param_dict,
                                      image_shape=self.image_shape,
                                      is_training=is_training,
                                      n_classes=flags.n_classes,
                                      segnet_lambda=flags.segnet_lambda,
                                      dropout=flags.dropout,
                                      for_display=True)
    self.environment = Environment.create_environment(flags.env_type, flags.env_name, flags.termination_time_sec,
                                                      env_args={'episode_schedule': flags.split,
                                                                'log_action_trace': flags.log_action_trace,
                                                                'max_states_per_scene': flags.episodes_per_scene,
                                                                'episodes_per_scene_test': flags.episodes_per_scene})

    self.global_network.prepare_loss()

    self.total_loss = []
    self.segm_loss = []
    self.episode_reward = [0]
    self.episode_roomtype = []
    self.roomType_dict  = {}
    self.segnet_class_dict = {}
    self.success_rate = []
    self.batch_size = 20
    self.batch_cur_num = 0
    self.batch_prev_num = 0
    self.batch_si = []
    self.batch_sobjT = []
    self.batch_a = []
    self.batch_reward = []
Example #8
    def __init__(self, thread_index, global_network, initial_learning_rate,
                 learning_rate_input, grad_applier, env_type, env_name,
                 use_lstm, use_pixel_change, use_value_replay,
                 use_reward_prediction, pixel_change_lambda, entropy_beta,
                 local_t_max, n_step_TD, gamma, gamma_pc,
                 experience_history_size, max_global_time_step, device,
                 segnet_param_dict, image_shape, is_training, n_classes,
                 random_state, termination_time, segnet_lambda, dropout):

        self.thread_index = thread_index
        self.learning_rate_input = learning_rate_input
        self.env_type = env_type
        self.env_name = env_name
        self.use_lstm = use_lstm
        self.use_pixel_change = use_pixel_change
        self.use_value_replay = use_value_replay
        self.use_reward_prediction = use_reward_prediction
        self.local_t_max = local_t_max
        self.n_step_TD = n_step_TD
        self.gamma = gamma
        self.gamma_pc = gamma_pc
        self.experience_history_size = experience_history_size
        self.max_global_time_step = max_global_time_step
        self.action_size = Environment.get_action_size(env_type, env_name)
        self.objective_size = Environment.get_objective_size(
            env_type, env_name)

        self.segnet_param_dict = segnet_param_dict
        self.segnet_mode = self.segnet_param_dict.get("segnet_mode", None)

        self.is_training = is_training
        self.n_classes = n_classes
        self.segnet_lambda = segnet_lambda

        self.run_metadata = tf.RunMetadata()
        self.many_runs_timeline = TimeLiner()

        self.random_state = random_state
        self.termination_time = termination_time
        self.dropout = dropout

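        # Build this worker's local network and training ops; any failure is re-raised
        # with the thread index for easier debugging.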
        try:
            self.local_network = UnrealModel(
                self.action_size,
                self.objective_size,
                thread_index,
                use_lstm,
                use_pixel_change,
                use_value_replay,
                use_reward_prediction,
                pixel_change_lambda,
                entropy_beta,
                device,
                segnet_param_dict=self.segnet_param_dict,
                image_shape=image_shape,
                is_training=is_training,
                n_classes=n_classes,
                segnet_lambda=self.segnet_lambda,
                dropout=dropout)

            self.local_network.prepare_loss()

            self.apply_gradients = grad_applier.minimize_local(
                self.local_network.total_loss, global_network.get_vars(),
                self.local_network.get_vars(), self.thread_index)

            self.sync = self.local_network.sync_from(global_network)
            self.experience = Experience(self.experience_history_size,
                                         random_state=self.random_state)
            self.local_t = 0
            self.initial_learning_rate = initial_learning_rate
            self.episode_reward = 0
            # For log output
            self.prev_local_t = -1
            self.prev_local_t_loss = 0
            self.sr_size = 50
            self.success_rates = deque(maxlen=self.sr_size)
        except Exception as e:
            print(str(e))  #, flush=True)
            raise Exception(
                "Problem in Trainer {} initialization".format(thread_index))
Example #9
    def run(self):
        device = "/cpu:0"
        if USE_GPU:
            device = "/gpu:0"

        initial_learning_rate = log_uniform(flags.initial_alpha_low,
                                            flags.initial_alpha_high,
                                            flags.initial_alpha_log_rate)

        self.global_t = 0

        self.stop_requested = False
        self.terminate_reqested = False

        action_size = Environment.get_action_size(flags.env_type,
                                                  flags.env_name)
        objective_size = Environment.get_objective_size(
            flags.env_type, flags.env_name)

        self.global_network = UnrealModel(
            action_size, objective_size, -1, flags.use_lstm,
            flags.use_pixel_change, flags.use_value_replay,
            flags.use_reward_prediction, flags.pixel_change_lambda,
            flags.entropy_beta, device)
        self.trainers = []

        learning_rate_input = tf.placeholder("float")

        grad_applier = RMSPropApplier(learning_rate=learning_rate_input,
                                      decay=flags.rmsp_alpha,
                                      momentum=0.0,
                                      epsilon=flags.rmsp_epsilon,
                                      clip_norm=flags.grad_norm_clip,
                                      device=device)

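        # One Trainer per parallel worker; all workers share the global network and optimizer.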
        for i in range(flags.parallel_size):
            trainer = Trainer(i, self.global_network, initial_learning_rate,
                              learning_rate_input, grad_applier,
                              flags.env_type, flags.env_name, flags.use_lstm,
                              flags.use_pixel_change, flags.use_value_replay,
                              flags.use_reward_prediction,
                              flags.pixel_change_lambda, flags.entropy_beta,
                              flags.local_t_max, flags.gamma, flags.gamma_pc,
                              flags.experience_history_size,
                              flags.max_time_step, device)
            self.trainers.append(trainer)

        # prepare session
        config = tf.ConfigProto(log_device_placement=False,
                                allow_soft_placement=True)
        config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=config)

        self.sess.run(tf.global_variables_initializer())

        # summary for tensorboard
        self.score_input = tf.placeholder(tf.int32)
        tf.summary.scalar("score", self.score_input)

        self.summary_op = tf.summary.merge_all()
        self.summary_writer = tf.summary.FileWriter(flags.log_dir,
                                                    self.sess.graph)

        # init or load checkpoint with saver
        self.saver = tf.train.Saver(self.global_network.get_vars(),
                                    max_to_keep=0)

        checkpoint = tf.train.get_checkpoint_state(flags.checkpoint_dir)
        if checkpoint and checkpoint.model_checkpoint_path:
            self.saver.restore(self.sess, checkpoint.model_checkpoint_path)
            print("checkpoint loaded:", checkpoint.model_checkpoint_path)
            tokens = checkpoint.model_checkpoint_path.split("-")
            # set global step
            self.global_t = int(tokens[1])
            print(">>> global step set: ", self.global_t)
            # set wall time
            wall_t_fname = flags.checkpoint_dir + '/' + 'wall_t.' + str(
                self.global_t)
            with open(wall_t_fname, 'r') as f:
                self.wall_t = float(f.read())
                self.next_save_steps = (
                    self.global_t + flags.save_interval_step
                ) // flags.save_interval_step * flags.save_interval_step

        else:
            print("Could not find old checkpoint")
            # set wall time
            self.wall_t = 0.0
            self.next_save_steps = flags.save_interval_step

        if flags.pretrain_dir != "":
            checkpoint = tf.train.get_checkpoint_state(flags.pretrain_dir)
            if checkpoint and checkpoint.model_checkpoint_path:
                print("restore pretrained model in {}".format(
                    flags.pretrain_dir))
                self.saver.restore(self.sess, checkpoint.model_checkpoint_path)

        # run training threads
        self.train_threads = []
        for i in range(flags.parallel_size):
            self.train_threads.append(
                threading.Thread(target=self.train_function, args=(i, True)))

        signal.signal(signal.SIGINT, self.signal_handler)

        # set start time
        self.start_time = time.time() - self.wall_t

        for t in self.train_threads:
            t.start()

        print('Press Ctrl+C to stop')
        signal.pause()
Example #10
  def run(self):
    device = "/cpu:0"
    if USE_GPU:
      device = "/gpu:0"

    self.print_flags_info()

    if flags.segnet == -1:
      with open(flags.segnet_config) as f:
        self.config = json.load(f)

      self.num_classes = self.config["NUM_CLASSES"]
      self.use_vgg = self.config["USE_VGG"]

      if self.use_vgg is False:
        self.vgg_param_dict = None
        print("No VGG path in config, so learning from scratch")
      else:
        self.vgg16_npy_path = self.config["VGG_FILE"]
        self.vgg_param_dict = np.load(self.vgg16_npy_path, encoding='latin1').item()
        print("VGG parameter loaded")

      self.bayes = self.config["BAYES"]
      segnet_param_dict = {'segnet_mode': flags.segnet, 'vgg_param_dict': self.vgg_param_dict, 'use_vgg': self.use_vgg,
                       'num_classes': self.num_classes, 'bayes': self.bayes}
    else: # 0, 1, 2, 3
      segnet_param_dict = {'segnet_mode': flags.segnet}

    if flags.env_type != 'indoor':
        env_config = {}
    else:
        env_config = sim_config.get(flags.env_name)
    self.image_shape = [env_config.get('height', 84), env_config.get('width', 84)]
    self.map_file = env_config.get('objecttypes_file', '../../objectTypes_1x.csv')
    
    initial_learning_rate = log_uniform(flags.initial_alpha_low,
                                        flags.initial_alpha_high,
                                        flags.initial_alpha_log_rate)
    self.global_t = 0
    
    self.stop_requested = False
    self.terminate_requested = False
    
    action_size = Environment.get_action_size(flags.env_type,
                                              flags.env_name)
    objective_size = Environment.get_objective_size(flags.env_type, flags.env_name)

    is_training = tf.placeholder(tf.bool, name="training")

    self.random_state = np.random.RandomState(seed=env_config.get("seed", 0xA3C))

    print("Global network initializing!")#, flush=True)

    self.global_network = UnrealModel(action_size,
                                      objective_size,
                                      -1,
                                      flags.use_lstm,
                                      flags.use_pixel_change,
                                      flags.use_value_replay,
                                      flags.use_reward_prediction,
                                      flags.pixel_change_lambda,
                                      flags.entropy_beta,
                                      device,
                                      segnet_param_dict=segnet_param_dict,
                                      image_shape=self.image_shape,
                                      is_training=is_training,
                                      n_classes=flags.n_classes,
                                      segnet_lambda=flags.segnet_lambda,
                                      dropout=flags.dropout)
    self.trainers = []
    
    learning_rate_input = tf.placeholder("float")

    
    grad_applier = RMSPropApplier(learning_rate = learning_rate_input,
                                  decay = flags.rmsp_alpha,
                                  momentum = 0.0,
                                  epsilon = flags.rmsp_epsilon,
                                  clip_norm = flags.grad_norm_clip,
                                  device = device)
    
    for i in range(flags.parallel_size):
      trainer = Trainer(i,
                        self.global_network,
                        initial_learning_rate,
                        learning_rate_input,
                        grad_applier,
                        flags.env_type,
                        flags.env_name,
                        flags.use_lstm,
                        flags.use_pixel_change,
                        flags.use_value_replay,
                        flags.use_reward_prediction,
                        flags.pixel_change_lambda,
                        flags.entropy_beta,
                        flags.local_t_max,
                        flags.n_step_TD,
                        flags.gamma,
                        flags.gamma_pc,
                        flags.experience_history_size,
                        flags.max_time_step,
                        device,
                        segnet_param_dict=segnet_param_dict,
                        image_shape=self.image_shape,
                        is_training=is_training,
                        n_classes = flags.n_classes,
                        random_state=self.random_state,
                        termination_time=flags.termination_time_sec,
                        segnet_lambda=flags.segnet_lambda,
                        dropout=flags.dropout)
      self.trainers.append(trainer)


    self.last_scores = []
    self.best_score = -1.0
    
    # prepare session
    config = tf.ConfigProto(allow_soft_placement = True, log_device_placement = False)
    config.gpu_options.allow_growth = True
    self.sess = tf.Session(config=config)

    # Wrap sess.run for debugging messages!
    def run_(*args, **kwargs):
      #print(">>> RUN!", args[0] if args else None)#, flush=True)
      return self.sess.__run(*args, **kwargs)  # getattr(self, "__run")(self, *args, **kwargs)
    self.sess.__run, self.sess.run = self.sess.run, run_

    self.sess.run(tf.global_variables_initializer())

    # summary for tensorboard
    self.score_input = tf.placeholder(tf.float32)
    self.sr_input = tf.placeholder(tf.float32)
    self.mIoU_input = tf.placeholder(tf.float32)
    self.term_global_t = tf.placeholder(tf.int32)

    self.losses_input = {}

    self.total_loss = tf.placeholder(tf.float32,  name='total_loss')
    self.losses_input.update({'all/total_loss': self.total_loss})

    self.base_loss = tf.placeholder(tf.float32, name='base_loss')
    self.losses_input.update({'all/base_loss': self.base_loss})

    self.policy_loss = tf.placeholder(tf.float32,  name='policy_loss')
    self.losses_input.update({'all/policy_loss': self.policy_loss})

    self.value_loss = tf.placeholder(tf.float32, name='value_loss')
    self.losses_input.update({'all/value_loss': self.value_loss})

    self.grad_norm = tf.placeholder(tf.float32, name='grad_norm')
    self.losses_input.update({'all/loss/grad_norm': self.grad_norm})

    self.entropy_input = tf.placeholder(tf.float32, shape=[None], name='entropy')

    if segnet_param_dict["segnet_mode"] >= 2:
      self.decoder_loss = tf.placeholder(tf.float32,  name='decoder_loss')
      self.losses_input.update({'all/decoder_loss': self.decoder_loss})
      self.l2_weights_loss = tf.placeholder(tf.float32, name='regul_weights_loss')
      self.losses_input.update({'all/l2_weights_loss': self.l2_weights_loss})
    if flags.use_pixel_change:
      self.pc_loss = tf.placeholder(tf.float32,  name='pc_loss')
      self.losses_input.update({'all/pc_loss': self.pc_loss})
    if flags.use_value_replay:
      self.vr_loss = tf.placeholder(tf.float32,  name='vr_loss')
      self.losses_input.update({'all/vr_loss': self.vr_loss})
    if flags.use_reward_prediction:
      self.rp_loss = tf.placeholder(tf.float32,  name='rp_loss')
      self.losses_input.update({'all/rp_loss': self.rp_loss})

    score_summary = tf.summary.scalar("all/eval/score", self.score_input)
    sr_summary = tf.summary.scalar("all/eval/success_rate", self.sr_input)
    term_summary = tf.summary.scalar("all/eval/term_global_t", self.term_global_t)
    eval_summary = tf.summary.scalar("all/eval/mIoU_all", self.mIoU_input)
    losses_summary_list = []
    for key, val in self.losses_input.items():
      losses_summary_list += [tf.summary.scalar(key, val)]


    self.summary_op_dict = {'score_input': score_summary, 'eval_input': eval_summary, 'sr_input':sr_summary,
                            'losses_input': tf.summary.merge(losses_summary_list),
                            'entropy': tf.summary.scalar('all/eval/entropy_stepTD', tf.reduce_mean(self.entropy_input)),
                            'term_global_t': term_summary}
    flags.checkpoint_dir = os.path.join(base_dir, flags.checkpoint_dir)
    #print("First dirs {}::{}".format(flags.log_dir, flags.checkpoint_dir))
    print("Checkpoint dir: {}, Log dir: {}".format(flags.checkpoint_dir, flags.log_dir))
    overall_FW = tf.summary.FileWriter(os.path.join(flags.log_dir, 'overall'),
                          self.sess.graph)
    self.summary_writer = [(tf.summary.FileWriter(os.path.join(flags.log_dir, 'worker_{}'.format(i)),
                                                self.sess.graph), overall_FW) for i in range(flags.parallel_size)]
    
    # init or load checkpoint with saver
    self.saver = tf.train.Saver(self.global_network.get_global_vars(), max_to_keep=20)


    
    #checkpoint = tf.train.get_checkpoint_state(flags.checkpoint_dir, latest_filename ="best-checkpoint")
    #if checkpoint is None or checkpoint.model_checkpoint_path is None:
    #  checkpoint = tf.train.get_checkpoint_state(flags.checkpoint_dir)
    checkpoint = tf.train.get_checkpoint_state(flags.checkpoint_dir)

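    # With segnet == -1, checkpoint tensors are remapped to the current graph's variables
    # by matching name suffixes before restoring.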
    if checkpoint and checkpoint.model_checkpoint_path:
      if flags.segnet == -1:
          from tensorflow.python import pywrap_tensorflow
          reader = pywrap_tensorflow.NewCheckpointReader(checkpoint.model_checkpoint_path)
          big_var_to_shape_map = reader.get_variable_to_shape_map()
          s = []
          for key in big_var_to_shape_map:
              s += [key]
              # print("tensor_name: ", key)
          glob_var_names = [v.name for v in tf.global_variables()]
          endings = [r.split('/')[-1][:-2] for r in glob_var_names]
          old_ckpt_to_new_ckpt = {[k for k in s if endings[i] in k][0]: v for i, v in enumerate(tf.global_variables())}
          saver1 = tf.train.Saver(var_list=old_ckpt_to_new_ckpt)
          saver1.restore(self.sess, checkpoint.model_checkpoint_path)
      else:
          self.saver.restore(self.sess, checkpoint.model_checkpoint_path)
      print("checkpoint loaded:", checkpoint.model_checkpoint_path)
      tokens = checkpoint.model_checkpoint_path.split("-")
      # set global step
      if 'best' in checkpoint.model_checkpoint_path:
          files = os.listdir(flags.checkpoint_dir)
          max_g_step = 0
          max_best_score = -10
          for file in files:
            if '.meta' not in file or 'checkpoint' not in file:
              continue
            if len(tokens) == 2:
              continue
            if len(tokens) > 3:
              best_score = float('-0.'+file.split('-')[2]) if 'best' in file else float('-0.'+file.split('-')[1])
              if best_score > max_best_score:
                g_step = int(file.split('-')[3].split('.')[0]) if 'best' in file else int(file.split('-')[2].split('.')[0])
                if max_g_step < g_step:
                  max_g_step = g_step
            else:
              self.best_score = -1.0
              g_step = int(file.split('-')[2]) if 'best' in file else int(file.split('-')[1])
              if max_g_step < g_step:
                max_g_step = g_step
          self.best_score = max_best_score
          self.global_t = max_g_step
          print("Chosen g_step >>", g_step)
      else:
        if len(tokens) == 3:
          self.global_t = int(tokens[2])
        else:
          self.global_t = int(tokens[1])
      #for i in range(flags.parallel_size):
      #  self.trainers[i].local_t = self.global_t
      print(">>> global step set: ", self.global_t)
      # set wall time
      wall_t_fname = flags.checkpoint_dir + '/' + 'wall_t.' + str(self.global_t)
      with open(wall_t_fname, 'r') as f:
        self.wall_t = float(f.read())
        self.next_save_steps = (self.global_t + flags.save_interval_step) // flags.save_interval_step * flags.save_interval_step
      print_tensors_in_checkpoint_file(file_name=checkpoint.model_checkpoint_path,
                                     tensor_name='', all_tensors=False, all_tensor_names=True)
    else:
      print("Could not find old checkpoint")
      # set wall time
      self.wall_t = 0.0
      self.next_save_steps = flags.save_interval_step
    print("Global step {}, max best score {}".format(self.global_t, self.best_score))

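    # Optionally warm-start the encoder/decoder from pretrained ErfNet segmentation weights.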
    if flags.segnet_pretrain:
      checkpoint_dir = "../erfnet_segmentation/models"
      checkpoint_dir = os.path.join(checkpoint_dir, "aug_erfnetC_0_{}x{}_{}x/snapshots_best".format(
          self.image_shape[1],
          self.image_shape[0],
          self.map_file.split('_')[1].split('x')[0]))
      checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)

      big_weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='net_-1/base_encoder')
      big_weights += tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='net_-1/base_decoder')

      erfnet_weights = [l.name.split(':')[0].rsplit('net_-1/base_encoder/')[-1] for l in big_weights
             if len(l.name.split(':')[0].rsplit('net_-1/base_encoder/')) == 2]
      erfnet_weights += [l.name.split(':')[0].rsplit('net_-1/base_decoder/')[-1] for l in big_weights
              if len(l.name.split(':')[0].rsplit('net_-1/base_decoder/')) == 2]

      if checkpoint and checkpoint.model_checkpoint_path:
          saver2 = tf.train.Saver(var_list=dict(zip(erfnet_weights, big_weights)))
          saver2.restore(self.sess, checkpoint.model_checkpoint_path)
          print("ErfNet pretrained weights restored from file ", checkpoint_dir)
      else:
          print("Can't load pretrained weights for ErfNet from file ", checkpoint_dir)

    # run training threads
    self.train_threads = []
    for i in range(flags.parallel_size):
      self.train_threads.append(threading.Thread(target=self.train_function, args=(i,True)))
      
    signal.signal(signal.SIGINT, self.signal_handler)
  
    # set start time
    self.start_time = time.time() - self.wall_t

    print("Ready to start")
    for t in self.train_threads:
      t.start()
  
    print('Press Ctrl+C to stop')#, flush=True)
    signal.pause()