Example #1
    def __init__(self,
                 name,
                 input_shape,
                 n_out,
                 epsW=0.001,
                 epsB=0.002,
                 initW=0.01,
                 initB=0.0,
                 momW=0.0,
                 momB=0.0,
                 wc=0.0,
                 dropRate=0.0,
                 weight=None,
                 bias=None,
                 weightIncr=None,
                 biasIncr=None):
        self.inputShape = input_shape
        self.inputSize, self.batchSize = input_shape

        self.outputSize = n_out
        self.dropRate = dropRate

        self.weightShape = (self.outputSize, self.inputSize)
        self.biasShape = (self.outputSize, 1)
        WeightedLayer.__init__(self, name, 'fc', epsW, epsB, initW, initB,
                               momW, momB, wc, weight, bias, weightIncr,
                               biasIncr, self.weightShape, self.biasShape)
        util.log('%s dropRate: %s', self.name, self.dropRate)
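
A minimal usage sketch for the constructor above (the class name FCLayer and the call-site values are assumptions; the constructor itself only confirms that input_shape unpacks to (inputSize, batchSize)):

    # hypothetical construction of a fully-connected layer
    layer = FCLayer('fc10',
                    input_shape=(4096, 128),  # 4096 inputs per case, batch of 128
                    n_out=10,                 # 10 output units
                    dropRate=0.5)
    # layer.weightShape == (10, 4096), layer.biasShape == (10, 1)
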
Example #2
  def predict(self, save_layers=None, filename=None):
    self.net.save_layerouput(save_layers)
    self.print_net_summary()
    util.log('Starting predict...')
    save_output = []
    while self.curr_epoch < 2:
      start = time.time()
      self.test_data = self.test_dp.get_next_batch()
      self.curr_epoch = self.test_data.epoch
      self.curr_batch = self.test_data.batchnum

      self.num_test_minibatch = divup(self.test_data.data.shape[1], self.batch_size)
      for i in range(self.num_test_minibatch):
        input, label = self.get_next_minibatch(i, TEST)
        self.net.train_batch(input, label, TEST)
      cost, correct, numCase = self.net.get_batch_information()
      print >> sys.stderr, '%d.%d: error: %f logreg: %f time: %f' % (self.curr_epoch, self.curr_batch, 1 - correct, cost, time.time() - start)
      if save_layers is not None:
        save_output.extend(self.net.get_save_output())

    if save_layers is not None:
      if filename is not None:
        with open(filename, 'w') as f:
          cPickle.dump(save_output, f, protocol=-1)
        util.log('save layer output finished')
Example #3
    def predict(self, save_layers=None, filename=None):
        self.net.save_layerouput(save_layers)
        self.print_net_summary()
        util.log('Starting predict...')
        save_output = []
        while self.curr_epoch < 2:
            start = time.time()
            test_data = self.test_dp.get_next_batch(self.batch_size)

            input, label = test_data.data, test_data.labels
            self.net.train_batch(input, label, TEST)
            cost, correct, numCase = self.net.get_batch_information()
            self.curr_epoch = test_data.epoch
            self.curr_batch += 1
            print >> sys.stderr, '%d.%d: error: %f logreg: %f time: %f' % (
                self.curr_epoch, self.curr_batch, 1 - correct, cost,
                time.time() - start)
            if save_layers is not None:
                save_output.extend(self.net.get_save_output())

        if save_layers is not None:
            if filename is not None:
                with open(filename, 'w') as f:
                    cPickle.dump(save_output, f, protocol=-1)
                util.log('save layer output finished')
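
A sketch of reading back the layer outputs that predict() pickles above; the dump is a single cPickle'd list (the 'image.opt' filename is borrowed from the commented-out call in Example #27, so treat it as a placeholder):

    import cPickle

    # load the list written by cPickle.dump(save_output, f, protocol=-1)
    with open('image.opt') as f:
        outputs = cPickle.load(f)
    print 'loaded %d saved layer outputs' % len(outputs)
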
Example #4
  def save_checkpoint(self):
    model = {}
    model['batchnum'] = self.train_dp.get_batch_num()
    model['epoch'] = self.num_epoch + 1
    model['layers'] = self.net.get_dumped_layers()

    model['train_outputs'] = self.train_outputs
    model['test_outputs'] = self.test_outputs

    dic = {'model_state': model, 'op': None}
    self.print_net_summary()
    
    if not os.path.exists(self.checkpoint_dir):
      os.system('mkdir -p \'%s\'' % self.checkpoint_dir)
    
    saved_filename = [f for f in os.listdir(self.checkpoint_dir) if self.regex.match(f)]
    for f in saved_filename:
      os.remove(os.path.join(self.checkpoint_dir, f))
    checkpoint_filename = "test%d-%d.%d" % (self.test_id, self.curr_epoch, self.curr_batch)
    checkpoint_file_path = os.path.join(self.checkpoint_dir, checkpoint_filename)
    self.checkpoint_file = checkpoint_file_path
    print >> sys.stderr, checkpoint_file_path
    with open(checkpoint_file_path, 'w') as f:
      cPickle.dump(dic, f, protocol=-1)
    util.log('save file finished')
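
Restoring the checkpoint written by save_checkpoint() is a cPickle.load of the same dict; a sketch, assuming only the {'model_state': ..., 'op': None} layout shown above (the path is a hypothetical instance of the "test%d-%d.%d" filename pattern):

    import cPickle

    # hypothetical checkpoint path: test_id 0, epoch 1, batch 100
    with open('checkpoints/test0-1.100') as f:
        dic = cPickle.load(f)
    model = dic['model_state']
    print model['epoch'], model['batchnum'], len(model['layers'])
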
Example #5
  def __init__(self, learningRate, imgShape, numOutput, init_model):
    self.learningRate = learningRate
    self.batchSize, self.numColor, self.imgSize, _ = imgShape
    self.imgShapes = [imgShape]
    self.inputShapes = [(self.numColor * (self.imgSize ** 2), self.batchSize)]
    self.numOutput = numOutput
    self.layers = []
    self.outputs = []
    self.grads = []
    self.output = None
    self.save_layers = None
    self.save_output = []

    self.numCase = self.cost = self.correct = 0.0

    self.numConv = 0
    
    if 'model_state' in init_model:
      # Loading from a checkpoint
      add_layers(FastNetBuilder(), self, init_model['model_state']['layers'])
    elif is_cudaconvnet_config(init_model):
      # AlexK config file
      add_layers(CudaconvNetBuilder(), self, init_model)
    else:
      # FastNet config file
      add_layers(FastNetBuilder(), self, init_model)

    self.adjust_learning_rate(self.learningRate)

    util.log('Learning rates:')
    for l in self.layers:
      util.log('%s: %s %s', l.name, getattr(l, 'epsW', 0), getattr(l, 'epsB', 0))
Example #6
  def train(self):
    self.print_net_summary()
    util.log('Starting training...')
    while self.should_continue_training():
      train_data = self.train_dp.get_next_batch(self.batch_size)

      self.curr_epoch = train_data.epoch
      self.curr_batch += 1

      start = time.time()
      input, label = train_data.data, train_data.labels
      self.net.train_batch(input, label)
      if self.should_capture_training_data():
        self._capture_training_data()

      cost, correct, numCase = self.net.get_batch_information()
      self.train_outputs += [({'logprob': [cost, 1 - correct]}, numCase, time.time() - start)]
      print >> sys.stderr, '%d.%d: error: %f logreg: %f time: %f' % (self.curr_epoch, self.curr_batch, 1 - correct, cost, time.time() - start)

      if self.check_test_data():
        self.get_test_error()

      if self.factor != 1.0 and self.check_adjust_lr():
        self.adjust_lr()

      if self.check_save_checkpoint():
        self.save_checkpoint()

    self.get_test_error()
    self.save_checkpoint()
    self.report()
    self._finished_training()
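
Each entry appended to self.train_outputs above has the shape ({'logprob': [cost, error]}, numCase, elapsed_seconds). A sketch of a case-weighted summary over such a list (assuming cost and error are per-batch averages, which the append above suggests but does not prove):

    def summarize(train_outputs):
        # weight each batch by its number of cases
        total = float(sum(n for _, n, _ in train_outputs))
        cost = sum(o['logprob'][0] * n for o, n, _ in train_outputs) / total
        error = sum(o['logprob'][1] * n for o, n, _ in train_outputs) / total
        return cost, error
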
Example #7
def test_imagenet_loader():
  df = data.ImageNetDataProvider('/ssd/nn-data/imagenet/', 
                                 batch_range=range(1000), 
                                 category_range=range(20),
                                 batch_size=512)
  util.log('Index: %s', df.curr_batch_index)
  util.log('%s', df._get_next_batch()['data'].shape)
  util.log('Index: %s', df.curr_batch_index) 
  util.log('%s', df._get_next_batch()['data'].shape)
  util.log('Index: %s', df.curr_batch_index) 
Example #8
def test_imagenet_loader():
    df = data.ImageNetDataProvider('/ssd/nn-data/imagenet/',
                                   batch_range=range(1000),
                                   category_range=range(20),
                                   batch_size=512)
    util.log('Index: %s', df.curr_batch_index)
    util.log('%s', df._get_next_batch()['data'].shape)
    util.log('Index: %s', df.curr_batch_index)
    util.log('%s', df._get_next_batch()['data'].shape)
    util.log('Index: %s', df.curr_batch_index)
Example #9
  def __init__(self, data_dir, batch_range=None, category_range=None, batch_size=128):
    ParallelDataProvider.__init__(self, data_dir, batch_range)
    self.img_size = 256
    self.border_size = 16
    self.inner_size = 224
    self.batch_size = batch_size

    # self.multiview = dp_params['multiview_test'] and test
    self.multiview = 0
    self.num_views = 5 * 2
    self.data_mult = self.num_views if self.multiview else 1

    self.buffer_idx = 0

    dirs = glob.glob(data_dir + '/n*')
    synid_to_dir = {}
    for d in dirs:
      synid_to_dir[basename(d)[1:]] = d

    if category_range is None:
      cat_dirs = dirs
    else:
      cat_dirs = []
      for i in category_range:
        synid = self.batch_meta['label_to_synid'][i]
        # util.log('Using category: %d, synid: %s, label: %s', i, synid, self.batch_meta['label_names'][i])
        cat_dirs.append(synid_to_dir[synid])

    self.images = []
    batch_dict = dict((k, k) for k in self.batch_range)

    for d in cat_dirs:
      imgs = [v for i, v in enumerate(glob.glob(d + '/*.jpg')) if i in batch_dict]
      self.images.extend(imgs)

    self.images = np.array(self.images)

    # build index vector into 'images' and split into groups of batch-size
    image_index = np.arange(len(self.images))
    np.random.shuffle(image_index)

    self.batches = np.array_split(image_index,
                                  util.divup(len(self.images), batch_size))

    self.batch_range = range(len(self.batches))

    util.log('Starting data provider with %d batches', len(self.batches))
    np.random.shuffle(self.batch_range)

    imagemean = cPickle.loads(open(data_dir + '/image-mean.pickle').read())
    self.data_mean = (imagemean['data']
        .astype(np.single)
        .T
        .reshape((3, 256, 256))[:, self.border_size:self.border_size + self.inner_size, self.border_size:self.border_size + self.inner_size]
        .reshape((self.get_data_dims(), 1)))
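
util.divup above splits len(self.images) indices into ceil(n / batch_size) groups; a minimal sketch of the ceiling division it presumably performs (an assumption about util's API, not its actual source):

    def divup(a, b):
        # ceiling division on non-negative ints
        return (a + b - 1) // b

    assert divup(1000, 128) == 8   # 1000 images -> 8 batches of <= 128 each
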
Example #10
 def dump(self, checkpoint, suffix):
   self.checkpoint = checkpoint
   saved_filename = [f for f in os.listdir(self.checkpoint_dir) if self.regex.match(f)]
   for f in saved_filename:
     os.remove(os.path.join(self.checkpoint_dir, f))
   checkpoint_filename = "test%d-%d" % (self.test_id, suffix)
   self.checkpoint_file = os.path.join(self.checkpoint_dir, checkpoint_filename)
   print >> sys.stderr, self.checkpoint_file
   with open(self.checkpoint_file, 'w') as f:
     cPickle.dump(checkpoint, f, protocol=-1)
   util.log('save file finished')
Example #11
 def dump(self, checkpoint, suffix):
     self.checkpoint = checkpoint
     saved_filename = [
         f for f in os.listdir(self.checkpoint_dir) if self.regex.match(f)
     ]
     for f in saved_filename:
         os.remove(os.path.join(self.checkpoint_dir, f))
     checkpoint_filename = "test%d-%d" % (self.test_id, suffix)
     self.checkpoint_file = os.path.join(self.checkpoint_dir,
                                         checkpoint_filename)
     print >> sys.stderr, self.checkpoint_file
     with open(self.checkpoint_file, 'w') as f:
         cPickle.dump(checkpoint, f, protocol=-1)
     util.log('save file finished')
Example #12
  def train(self):
    self.print_net_summary()
    util.log('Starting training...')
    while self.should_continue_training():
      self.train_data = self.train_dp.get_next_batch()  # self.train_dp.wait()
      self.curr_epoch = self.train_data.epoch
      self.curr_batch = self.train_data.batchnum

      start = time.time()
      self.num_train_minibatch = divup(self.train_data.data.shape[1], self.batch_size)
      t = 0
      
      for i in range(self.num_train_minibatch):
        input, label = self.get_next_minibatch(i)
        stime = time.time()
        self.net.train_batch(input, label)
        self._capture_training_data()
        t += time.time() - stime
        self.curr_minibatch += 1

      cost, correct, numCase = self.net.get_batch_information()
      self.train_outputs += [({'logprob': [cost, 1 - correct]}, numCase, time.time() - start)]
      print >> sys.stderr, '%d.%d: error: %f logreg: %f time: %f' % (self.curr_epoch, self.curr_batch, 1 - correct, cost, time.time() - start)

      self.num_batch += 1
      if self.check_test_data():
        print >> sys.stderr, '---- test ----'
        self.get_test_error()
        print >> sys.stderr, '------------'

      if self.factor != 1.0 and self.check_adjust_lr():
        print >> sys.stderr, '---- adjust learning rate ----'
        self.net.adjust_learning_rate(self.factor)
        print >> sys.stderr, '--------'

      if self.check_save_checkpoint():
        print >> sys.stderr, '---- save checkpoint ----'
        self.save_checkpoint()
        print >> sys.stderr, '------------'

      #wait_time = time.time()
      #print 'waiting', time.time() - wait_time, 'secs to load'
      #print 'time to train a batch file is', time.time() - start

    self.get_test_error()
    self.save_checkpoint()
    self.report()
    self._finished_training()
Example #13
  def __init__(self, name, input_shape, n_out, epsW=0.001, epsB=0.002, initW=0.01, initB=0.0,
               momW=0.0, momB=0.0, wc=0.0, dropRate=0.0, weight=None, bias=None,
               weightIncr=None, biasIncr=None):
    self.inputShape = input_shape
    self.inputSize, self.batchSize = input_shape

    self.outputSize = n_out
    self.dropRate = dropRate

    self.weightShape = (self.outputSize, self.inputSize)
    self.biasShape = (self.outputSize, 1)
    WeightedLayer.__init__(self, name, 'fc', epsW, epsB, initW, initB, momW, momB, wc, weight,
        bias, weightIncr, biasIncr, self.weightShape, self.biasShape)
    util.log('%s dropRate: %s', self.name, self.dropRate)
Example #14
  def flush(self):
    if self.sz == 0:
      return

    out = {}
    for k in self.data[0].keys():
      items = [d[k] for d in self.data]
      out[k] = np.concatenate(items, axis=0)
    
    with open('%s.%d' % (self.target_path, self.count), 'w') as f:
      cPickle.dump(out, f, -1)

    util.log('Wrote layer dump.')
    self.data = []    
    self.sz = 0
    self.count += 1
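
A sketch of reading back the '%s.%d' dump files written by flush() above, in numeric order of their counter suffix; each file holds one dict of concatenated numpy arrays (load_dumps is a hypothetical helper, not part of the codebase):

    import cPickle
    import glob

    def load_dumps(target_path):
        # sort by the integer counter after the last '.'
        paths = sorted(glob.glob(target_path + '.*'),
                       key=lambda p: int(p.rsplit('.', 1)[1]))
        chunks = []
        for path in paths:
            with open(path) as f:
                chunks.append(cPickle.load(f))
        return chunks
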
Example #15
  def __init__(self, checkpoint_dir, test_id):
    self.checkpoint_dir = checkpoint_dir

    if not os.path.exists(self.checkpoint_dir):
      os.system('mkdir -p \'%s\'' % self.checkpoint_dir)

    self.test_id = test_id
    self.regex = re.compile(r'^test%d-(\d+)$' % self.test_id)

    cp_pattern = self.checkpoint_dir + '/test%d-*' % self.test_id
    cp_files = glob.glob(cp_pattern)

    if not cp_files:
      self.checkpoint = None
      self.checkpoint_file = None
    else:
      self.checkpoint_file = sorted(cp_files, key=os.path.getmtime)[-1]
      util.log('Loading from checkpoint file: %s', self.checkpoint_file)
      self.checkpoint = util.load(self.checkpoint_file)
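
A usage sketch tying this constructor to the dump() method from Example #10 (the class name CheckpointDumper is an assumption; only the attribute and method names shown in the snippets are confirmed):

    # hypothetical: resume-or-start, then write checkpoint 1
    dumper = CheckpointDumper('/tmp/checkpoints', test_id=0)
    if dumper.checkpoint is not None:
        util.log('resuming from %s', dumper.checkpoint_file)
    dumper.dump({'model_state': {}, 'op': None}, suffix=1)
    # writes /tmp/checkpoints/test0-1 and removes any older test0-* files
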
Example #16
    def __init__(self, checkpoint_dir, test_id):
        self.checkpoint_dir = checkpoint_dir

        if not os.path.exists(self.checkpoint_dir):
            os.system('mkdir -p \'%s\'' % self.checkpoint_dir)

        self.test_id = test_id
        self.regex = re.compile(r'^test%d-(\d+)$' % self.test_id)

        cp_pattern = self.checkpoint_dir + '/test%d-*' % self.test_id
        cp_files = glob.glob(cp_pattern)

        if not cp_files:
            self.checkpoint = None
            self.checkpoint_file = None
        else:
            self.checkpoint_file = sorted(cp_files, key=os.path.getmtime)[-1]
            util.log('Loading from checkpoint file: %s', self.checkpoint_file)
            self.checkpoint = util.load(self.checkpoint_file)
Example #17
  def cut_off_chunk(self):
    if len(self.memory_chunk) == 0:
      util.log('There is no chunk to cut off')
      return

    size = 0
    for k, v in self.memory_chunk[0].iteritems():
      size += v.nbytes

    del self.memory_chunk[0]
    self.total_data_size -= size
    util.log('dropped the first memory chunk')
    util.log('dropped chunk size:   %s', size)
    util.log('total data size:      %s', self.total_data_size)
Example #18
    def cut_off_chunk(self):
        if len(self.memory_chunk) == 0:
            util.log('There is no chunk to cut off')
            return

        size = 0
        for k, v in self.memory_chunk[0].iteritems():
            size += v.nbytes

        del self.memory_chunk[0]
        self.total_data_size -= size
        util.log('dropped the first memory chunk')
        util.log('dropped chunk size:   %s', size)
        util.log('total data size:      %s', self.total_data_size)
Example #19
    def train(self):
        self.print_net_summary()
        util.log('Starting training...')
        while self.should_continue_training():
            train_data = self.train_dp.get_next_batch(self.batch_size)

            self.curr_epoch = train_data.epoch
            self.curr_batch += 1

            start = time.time()
            input, label = train_data.data, train_data.labels
            self.net.train_batch(input, label)
            if self.should_capture_training_data():
                self._capture_training_data()

            cost, correct, numCase = self.net.get_batch_information()
            self.train_outputs += [({
                'logprob': [cost, 1 - correct]
            }, numCase, time.time() - start)]
            print >> sys.stderr, '%d.%d: error: %f logreg: %f time: %f' % (
                self.curr_epoch, self.curr_batch, 1 - correct, cost,
                time.time() - start)

            if self.check_test_data():
                self.get_test_error()

            if self.factor != 1.0 and self.check_adjust_lr():
                self.adjust_lr()

            if self.check_save_checkpoint():
                self.save_checkpoint()

        self.get_test_error()
        self.save_checkpoint()
        self.report()
        self._finished_training()
Example #20
  def __init__(self, target_path, max_mem_size=500e5):
    self.target_path = target_path
    self.data = []
    self.sz = 0
    self.count = 0
    self.max_mem_size = max_mem_size

    util.log('dumper established')
    util.log('target path:    %s', self.target_path)
    util.log('max_memory:     %s', self.max_mem_size)
Example #21
    def __init__(self, target_path, max_mem_size=500e5):
        self.target_path = target_path
        self.data = []
        self.sz = 0
        self.count = 0
        self.max_mem_size = max_mem_size

        util.log('dumper established')
        util.log('target path:    %s', self.target_path)
        util.log('max_memory:     %s', self.max_mem_size)
Example #22
    def __init__(self, single_memory_size=50e6, total_memory_size=2e9):
        self.single_memory_size = single_memory_size
        self.total_memory_size = total_memory_size
        self.single_data_size = 0
        self.total_data_size = 0
        self.count = 0
        self.data = []
        self.memory_chunk = []

        util.log('memory data holder established')
        util.log('total memory size:    %s', self.total_memory_size)
        util.log('single memory size:   %s', self.single_memory_size)
Example #23
  def __init__(self, single_memory_size=50e6, total_memory_size=2e9):
    self.single_memory_size = single_memory_size
    self.total_memory_size = total_memory_size
    self.single_data_size = 0
    self.total_data_size = 0
    self.count = 0
    self.data = []
    self.memory_chunk = []

    util.log('memory data holder established')
    util.log('total memory size:    %s', self.total_memory_size)
    util.log('single memory size:   %s', self.single_memory_size)
Example #24
    def flush(self):
        if self.single_data_size == 0:
            return

        dic = {}
        for k in self.data[0].keys():
            items = [d[k] for d in self.data]
            dic[k] = np.concatenate(items, axis=0)

        self.memory_chunk.append(dic)

        util.log('add another memory chunk')
        util.log('memory chunk size:    %s', self.single_data_size)
        util.log('total data size:    %s', self.total_data_size)

        self.data = []
        self.single_data_size = 0
        self.count += 1
Example #25
  def flush(self):
    if self.single_data_size == 0:
      return

    dic = {}
    for k in self.data[0].keys():
      items = [d[k] for d in self.data]
      dic[k] = np.concatenate(items, axis=0)

    self.memory_chunk.append(dic)

    util.log('add another memory chunk')
    util.log('memory chunk size:    %s', self.single_data_size)
    util.log('total data size:    %s', self.total_data_size)

    self.data = []
    self.single_data_size = 0
    self.count += 1
Example #26
 def _log(self, fmt, *args):
     util.log('%s :: %s', rank, fmt % args)
Example #27
  init_model = cp_dumper.get_checkpoint()
  if init_model is None:
    init_model = parse_config_file(args.param_file)
  param_dict['init_model'] = init_model

  #create train dataprovider and test dataprovider
  dp_class = DataProvider.get_by_name(param_dict['data_provider'])
  train_dp = dp_class(param_dict['data_dir'], param_dict['train_range'])
  test_dp = dp_class(param_dict['data_dir'], param_dict['test_range'])
  param_dict['train_dp'] = train_dp
  param_dict['test_dp'] = test_dp


  #get all extra information
  param_dict['num_epoch'] = args.num_epoch
  num_batch = util.string_to_int_list(args.num_batch)
  if len(num_batch) == 1:
    param_dict['num_batch'] = num_batch[0]
  else:
    param_dict['num_batch'] = num_batch

  param_dict['num_group_list'] = util.string_to_int_list(args.num_group_list)
  param_dict['num_caterange_list'] = util.string_to_int_list(args.num_caterange_list)
  param_dict['output_dir'] = args.output_dir


  trainer = Trainer.get_trainer_by_name(trainer, param_dict)
  util.log('starting training...')
  trainer.train()
  #trainer.predict(['pool5'], 'image.opt')
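
util.string_to_int_list is used throughout this setup code but not shown; a plausible minimal sketch, assuming it accepts comma-separated integers with optional 'a-b' ranges (purely an assumption about its behavior, not the library's actual source):

    def string_to_int_list(s):
        # hypothetical parser: '1-3,7' -> [1, 2, 3, 7]
        if s is None:
            return []
        out = []
        for part in str(s).split(','):
            if '-' in part:
                lo, hi = part.split('-')
                out.extend(range(int(lo), int(hi) + 1))
            else:
                out.append(int(part))
        return out
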
Example #28
    def __init__(self,
                 data_dir,
                 batch_range=None,
                 category_range=None,
                 batch_size=128):
        ParallelDataProvider.__init__(self, data_dir, batch_range)
        self.img_size = 256
        self.border_size = 16
        self.inner_size = 224
        self.batch_size = batch_size

        # self.multiview = dp_params['multiview_test'] and test
        self.multiview = 0
        self.num_views = 5 * 2
        self.data_mult = self.num_views if self.multiview else 1

        self.buffer_idx = 0

        dirs = glob.glob(data_dir + '/n*')
        synid_to_dir = {}
        for d in dirs:
            synid_to_dir[basename(d)[1:]] = d

        if category_range is None:
            cat_dirs = dirs
        else:
            cat_dirs = []
            for i in category_range:
                synid = self.batch_meta['label_to_synid'][i]
                # util.log('Using category: %d, synid: %s, label: %s', i, synid, self.batch_meta['label_names'][i])
                cat_dirs.append(synid_to_dir[synid])

        self.images = []
        batch_dict = dict((k, k) for k in self.batch_range)

        for d in cat_dirs:
            imgs = [
                v for i, v in enumerate(glob.glob(d + '/*.jpg'))
                if i in batch_dict
            ]
            self.images.extend(imgs)

        self.images = np.array(self.images)

        # build index vector into 'images' and split into groups of batch-size
        image_index = np.arange(len(self.images))
        np.random.shuffle(image_index)

        self.batches = np.array_split(image_index,
                                      util.divup(len(self.images), batch_size))

        self.batch_range = range(len(self.batches))

        util.log('Starting data provider with %d batches', len(self.batches))
        np.random.shuffle(self.batch_range)

        imagemean = cPickle.loads(open(data_dir + '/image-mean.pickle').read())
        self.data_mean = (imagemean['data']
            .astype(np.single)
            .T
            .reshape((3, 256, 256))[:, self.border_size:self.border_size + self.inner_size,
                                       self.border_size:self.border_size + self.inner_size]
            .reshape((self.get_data_dims(), 1)))
Example #29
  param_dict = {}
  param_dict['image_color'] = 3
  param_dict['test_id'] = args.test_id
  param_dict['data_dir'] = args.data_dir
  param_dict['data_provider'] = args.data_provider
  if args.data_provider.startswith('imagenet'):
    param_dict['image_size'] = 224
  elif args.data_provider.startswith('cifar'):
    param_dict['image_size'] = 32
  else:
    assert False, 'Unknown data_provider %s' % args.data_provider
 
  param_dict['train_range'] = util.string_to_int_list(args.train_range)
  param_dict['test_range'] = util.string_to_int_list(args.test_range)
  util.log('%s %s', args.test_range, param_dict['test_range'])
  param_dict['save_freq'] = args.save_freq
  param_dict['test_freq'] = args.test_freq
  param_dict['adjust_freq'] = args.adjust_freq
  factor = util.string_to_float_list(args.factor)
  if len(factor) == 1:
    param_dict['factor'] = factor[0]
  else:
    param_dict['factor'] = factor


  learning_rate = util.string_to_float_list(args.learning_rate)
  if len(learning_rate) == 1:
    param_dict['learning_rate'] = learning_rate[0]
  else:
    param_dict['learning_rate'] = learning_rate
Example #30
 def _log(self, fmt, *args):
   util.log('%s :: %s', rank, fmt % args)
Example #31
    #create the init_model
    init_model = cp_dumper.get_checkpoint()
    if init_model is None:
        init_model = parse_config_file(args.param_file)
    param_dict['init_model'] = init_model

    #create train dataprovider and test dataprovider
    dp_class = DataProvider.get_by_name(param_dict['data_provider'])
    train_dp = dp_class(param_dict['data_dir'], param_dict['train_range'])
    test_dp = dp_class(param_dict['data_dir'], param_dict['test_range'])
    param_dict['train_dp'] = train_dp
    param_dict['test_dp'] = test_dp

    #get all extra information
    param_dict['num_epoch'] = args.num_epoch
    num_batch = util.string_to_int_list(args.num_batch)
    if len(num_batch) == 1:
        param_dict['num_batch'] = num_batch[0]
    else:
        param_dict['num_batch'] = num_batch

    param_dict['num_group_list'] = util.string_to_int_list(args.num_group_list)
    param_dict['num_caterange_list'] = util.string_to_int_list(
        args.num_caterange_list)
    param_dict['output_dir'] = args.output_dir

    trainer = Trainer.get_trainer_by_name(trainer, param_dict)
    util.log('starting training...')
    trainer.train()
    #trainer.predict(['pool5'], 'image.opt')