示例#1
0
    def __init__(self, args):
        """Set up everything needed to evaluate saved checkpoints.

        Creates a timestamped result directory and opens its log file, then
        builds the dataset batch iterator, the model list, the feed-dict
        producer and a ``tf.train.Saver``.

        Args:
            args: command-line options (presumably an ``argparse.Namespace``
                -- confirm against the caller). Attributes read:
                ``cls_threshold``, ``eval_interval_secs``,
                ``restore_model_path``, ``split`` and ``img_list``. Its
                ``str()`` form is also written to the log file.
        """
        self.batch_size = cfg.TRAIN.CONFIG.BATCH_SIZE
        self.gpu_num = cfg.TRAIN.CONFIG.GPU_NUM
        self.num_workers = cfg.DATA_LOADER.NUM_THREADS
        self.log_dir = cfg.MODEL.PATH.EVALUATION_DIR
        self.is_training = False

        self.cls_thresh = float(args.cls_threshold)
        self.eval_interval_secs = args.eval_interval_secs
        self.restore_model_path = args.restore_model_path

        self.eval = True

        # save dir: <EVALUATION_DIR>/<restore_model_path>/<current time>.
        # NOTE(review): os.path.join discards EVALUATION_DIR when
        # restore_model_path is absolute, so results then land under the
        # restore path itself -- confirm that is intended.
        datetime_str = str(datetime.datetime.now())
        self.log_dir = os.path.join(self.log_dir, self.restore_model_path,
                                    datetime_str)
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(self.log_dir, exist_ok=True)
        # Kept open for the object's lifetime (presumably written to by
        # _log_string, which is defined outside this view).
        self.log_file = open(os.path.join(self.log_dir, 'log_train.txt'), 'w')
        self.log_file.write(str(args) + '\n')
        self._log_string('**** Saving Evaluation results to the path %s ****' %
                         self.log_dir)

        # dataset
        dataset_func = choose_dataset()
        self.dataset = dataset_func('loading',
                                    split=args.split,
                                    img_list=args.img_list,
                                    is_training=self.is_training,
                                    workers_num=self.num_workers)
        # One iterator serves all GPUs, hence batch_size * gpu_num samples per
        # batch (presumably split per tower by FeedDictCreater).
        self.dataset_iter = self.dataset.load_batch(self.batch_size *
                                                    self.gpu_num)
        self.val_size = len(self.dataset)
        self._log_string('**** Dataset length is %d ****' % self.val_size)

        # model list
        self.model_func = choose_model()
        (self.model_list, self.pred_list,
         self.placeholders) = self._build_model_list()

        # feeddict
        self.feeddict_producer = FeedDictCreater(self.dataset_iter,
                                                 self.model_list,
                                                 self.batch_size)

        # evaluation tools: presumably track the last evaluated checkpoint
        # and the best result so far (-1 meaning "no result yet").
        self.last_eval_model_path = None
        self.last_best_model = None
        self.last_best_result = -1
        self.saver = tf.train.Saver()
示例#2
0
    def __init__(self, args):
        """Build the multi-GPU training graph, session and logging.

        Creates a timestamped checkpoint directory containing a log file and
        a copy of the config file. Then builds the dataset batch iterator,
        the optimizer, the model towers with averaged gradients, the train
        op, a ``tf.Session``, a ``tf.train.Saver`` and a summary writer.
        Finally it calls ``self._initialize_model()``.

        Args:
            args: command-line options (presumably an ``argparse.Namespace``
                -- confirm against the caller). Attributes read:
                ``restore_model_path``, ``cfg`` (path of the config file to
                copy), ``split`` and ``img_list``. Its ``str()`` form is also
                written to the log file.

        Raises:
            ValueError: if ``cfg.SOLVER.TYPE`` is neither ``'SGD'`` nor
                ``'Adam'``.
        """
        import shutil  # local import: file-level imports are outside this view

        self.batch_size = cfg.TRAIN.CONFIG.BATCH_SIZE
        self.gpu_num = cfg.TRAIN.CONFIG.GPU_NUM
        self.num_workers = cfg.DATA_LOADER.NUM_THREADS
        self.log_dir = cfg.MODEL.PATH.CHECKPOINT_DIR
        self.max_iteration = cfg.TRAIN.CONFIG.MAX_ITERATIONS
        self.checkpoint_interval = cfg.TRAIN.CONFIG.CHECKPOINT_INTERVAL
        self.summary_interval = cfg.TRAIN.CONFIG.SUMMARY_INTERVAL
        self.trainable_param_prefix = cfg.TRAIN.CONFIG.TRAIN_PARAM_PREFIX
        self.trainable_loss_prefix = cfg.TRAIN.CONFIG.TRAIN_LOSS_PREFIX

        self.restore_model_path = args.restore_model_path
        self.is_training = True

        # gpu_num: never use more GPUs than are available. The helper's
        # result is passed to len(), so despite its name it presumably returns
        # a list of devices.
        # NOTE(review): this clamps to 0 when no GPU is found, which requests
        # zero-sized batches below -- confirm a GPU is always present.
        self.gpu_num = min(self.gpu_num, len(self._get_available_gpu_num()))

        # save dir: <CHECKPOINT_DIR>/<current time>
        datetime_str = str(datetime.datetime.now())
        self.log_dir = os.path.join(self.log_dir, datetime_str)
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(self.log_dir, exist_ok=True)
        self.log_file = open(os.path.join(self.log_dir, 'log_train.txt'), 'w')
        self.log_file.write(str(args) + '\n')
        self._log_string('**** Saving models to the path %s ****' %
                         self.log_dir)
        self._log_string('**** Saving configure file in %s ****' %
                         self.log_dir)
        # Copy the config next to the checkpoints. shutil.copy does not go
        # through a shell, so paths with quotes or shell metacharacters are
        # safe. As before, a failed copy does not abort training, but it is
        # now reported in the log instead of being ignored. TypeError covers
        # a missing (None) path.
        try:
            shutil.copy(args.cfg, self.log_dir)
        except (OSError, TypeError) as err:
            self._log_string('**** Failed to copy configure file %s: %s ****' %
                             (args.cfg, err))

        # dataset
        dataset_func = choose_dataset()
        self.dataset = dataset_func('loading',
                                    split=args.split,
                                    img_list=args.img_list,
                                    is_training=self.is_training,
                                    workers_num=self.num_workers)
        # One iterator serves all GPUs, hence batch_size * gpu_num samples per
        # batch (presumably split per tower by FeedDictCreater).
        self.dataset_iter = self.dataset.load_batch(self.batch_size *
                                                    self.gpu_num)
        self._log_string('**** Dataset length is %d ****' % len(self.dataset))

        # optimizer: ops created in this scope are pinned to /cpu:0
        with tf.device('/cpu:0'):
            self.global_step = tf.contrib.framework.get_or_create_global_step()
            self.bn_decay = get_bn_decay(self.global_step)
            self.learning_rate = get_learning_rate(self.global_step)
            if cfg.SOLVER.TYPE == 'SGD':
                self.optimizer = tf.train.MomentumOptimizer(
                    self.learning_rate, momentum=cfg.SOLVER.MOMENTUM)
            elif cfg.SOLVER.TYPE == 'Adam':
                self.optimizer = tf.train.AdamOptimizer(self.learning_rate)
            else:
                # Fail fast: otherwise self.optimizer stays unset and its
                # first use raises a confusing AttributeError.
                raise ValueError(
                    "Unsupported cfg.SOLVER.TYPE %r (expected 'SGD' or 'Adam')"
                    % (cfg.SOLVER.TYPE,))

        # models
        self.model_func = choose_model()
        (self.model_list, self.tower_grads, self.total_loss_gpu,
         self.losses_list, self.params,
         self.extra_update_ops) = self._build_model_list()
        tf.summary.scalar('total_loss', self.total_loss_gpu)

        # feeddict
        self.feeddict_producer = FeedDictCreater(self.dataset_iter,
                                                 self.model_list,
                                                 self.batch_size)

        # Average the per-tower gradients and apply them once. The extra
        # update ops from _build_model_list (presumably batch-norm statistic
        # updates -- confirm) are grouped into the same train op.
        with tf.device('/gpu:0'):
            self.grads = average_gradients(self.tower_grads)
            self.update_op = [
                self.optimizer.apply_gradients(zip(self.grads, self.params),
                                               global_step=self.global_step)
            ]
        self.update_op.extend(self.extra_update_ops)
        self.train_op = tf.group(*self.update_op)

        # tensorflow training ops. allow_growth makes TF allocate GPU memory
        # on demand. Soft placement lets ops fall back to a device that has a
        # kernel for them.
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=1,
                                    allow_growth=True)
        config = tf.ConfigProto(
            gpu_options=gpu_options,
            device_count={
                "GPU": self.gpu_num,
            },
            allow_soft_placement=True,
        )
        self.sess = tf.Session(config=config)

        self.saver = tf.train.Saver()
        self.merged = tf.summary.merge_all()
        self.train_writer = tf.summary.FileWriter(
            os.path.join(self.log_dir, 'train'), self.sess.graph)

        # initialize model
        self._initialize_model()