コード例 #1
0
  # Announce the backbone, then choose the input resolution: backbones whose
  # name contains 'Conv' get 84x84 images, every other backbone 224x224.
  print('  train with model: %s'%params.model)
  image_size = 84 if 'Conv' in params.model else 224

  # Query images per class: 16 scaled by test_n_way/train_n_way, truncated by
  # int() and floored at 1. The same n_query is shared by the training and the
  # validation data managers below.
  n_query = max(1, int(16 * params.test_n_way / params.train_n_way))

  # Episode shapes (ways / shots) for training and for validation.
  train_few_shot_params = {'n_way': params.train_n_way, 'n_support': params.n_shot}
  test_few_shot_params = {'n_way': params.test_n_way, 'n_support': params.n_shot}

  # Episodic data managers, both with n_eposide=100 (keyword spelled exactly as
  # passed here -- presumably it must match SetDataManager's signature).
  base_datamgr = SetDataManager(image_size, n_query=n_query, n_eposide=100,
                                **train_few_shot_params)
  val_datamgr = SetDataManager(image_size, n_query=n_query, n_eposide=100,
                               **test_few_shot_params)
  # NOTE(review): both validation loaders are built with identical arguments;
  # the '_nd' suffix suggests they were meant to differ -- confirm intent.
  val_loader = val_datamgr.get_data_loader(val_file, aug=False)
  val_loader_nd = val_datamgr.get_data_loader(val_file, aug=False)

  # Build the MAML wrapper (tf_path presumably a TensorBoard log dir -- confirm)
  # and call .cuda() on it in place; the return value is not rebound.
  model = MAML(params, tf_path=params.tf_dir)
  model.cuda()

  # resume training
  # Epoch bounds come straight from the command-line params; a successful
  # resume below overrides start_epoch with the value model.resume() returns.
  start_epoch = params.start_epoch
  stop_epoch = params.stop_epoch
  if params.resume != '':
    # Look for a checkpoint under '<save_dir>/checkpoints/<resume>'; presumably
    # params.resume_epoch selects which epoch's file is returned, and None means
    # nothing suitable was found (TODO confirm against get_resume_file).
    resume_file = get_resume_file('%s/checkpoints/%s'%(params.save_dir, params.resume), params.resume_epoch)
    if resume_file is not None:
      # model.resume() presumably restores the checkpoint state; its return
      # value is used as the epoch to continue training from.
      # NOTE(review): the log message reads "with at" -- wording typo.
      start_epoch = model.resume(resume_file)
      print('  resume the training with at {} epoch (model file {})'.format(start_epoch, params.resume))
    else:
      # A resume was explicitly requested but no file was found: fail loudly
      # instead of silently starting from scratch.
      raise ValueError('No resume file')
  # load pre-trained feature encoder
  # (reached only when no resume was requested)
  else:
    # 'scratch' keeps the freshly initialized weights untouched.
    # NOTE(review): this excerpt stops here; a branch loading params.warmup
    # weights likely follows in the full script -- confirm.
    if params.warmup == 'scratch':
      pass
コード例 #2
0
    # prepare model
    print('--- prepare model {0} (backbone {1}) ---'.format(params.method,
                                                          params.model))
    # Only the MAML family is handled here; reject anything else up front.
    if params.method not in ('maml', 'maml_approx'):
        raise ValueError('Unknown method')

    # Put the backbone building blocks into MAML mode *before* the network is
    # constructed. These are class-level attributes, so the switch is global.
    backbone.ConvBlock.maml = True
    backbone.SimpleBlock.maml = True
    backbone.BottleneckBlock.maml = True
    backbone.ResNet.maml = True
    # 'maml_approx' turns on the approximate variant (presumably first-order
    # MAML -- confirm in the MAML class).
    model = MAML(model_dict[params.model],
                 approx=(params.method == 'maml_approx'),
                 tf_path=params.tf_dir,
                 **train_few_shot_params)
    model = model.cuda()

    # Training bookkeeping counters (names suggest best accuracy and total
    # iterations), both starting at zero.
    max_acc, total_it = 0, 0
    start_epoch = params.start_epoch
    # Non-baseline methods scale the epoch budget by model.batch_size. Given
    # the guard above, params.method is always 'maml' or 'maml_approx' here,
    # so the baseline branch is never taken on this code path.
    if params.method in ['baseline', 'baseline++']:
        stop_epoch = params.stop_epoch
    else:
        stop_epoch = params.stop_epoch * model.batch_size

    # resume/warmup or not
    if params.resume != '':
        resume_file = get_resume_file('%s/checkpoints/%s' %
                                      (params.save_dir, params.resume))
        if resume_file is not None:
            print('  load resume file: {}'.format(params.resume))
            tmp = torch.load(resume_file)