base_datamgr            = SetDataManager(image_size, n_query = n_query, n_eposide = 100, **train_few_shot_params)
  test_few_shot_params    = dict(n_way = params.test_n_way, n_support = params.n_shot)
  val_datamgr             = SetDataManager(image_size, n_query = n_query, n_eposide = 100, **test_few_shot_params)
  val_loader              = val_datamgr.get_data_loader( val_file, aug = False)
  val_loader_nd           = val_datamgr.get_data_loader( val_file, aug = False)

  model = MAML(params, tf_path=params.tf_dir)
  model.cuda()

  # resume training
  start_epoch = params.start_epoch
  stop_epoch = params.stop_epoch
  if params.resume != '':
    resume_file = get_resume_file('%s/checkpoints/%s'%(params.save_dir, params.resume), params.resume_epoch)
    if resume_file is not None:
      start_epoch = model.resume(resume_file)
      print('  resume the training with at {} epoch (model file {})'.format(start_epoch, params.resume))
    else:
      raise ValueError('No resume file')
  # load pre-trained feature encoder
  else:
    if params.warmup == 'scratch':
      pass
    elif params.warmup == 'gg3b0':
      raise Exception('Must provide pre-trained feature-encoder file using --warmup option!')
    else:
      model.model.feature.load_state_dict(load_warmup_state('%s/checkpoints/%s'%(params.save_dir, params.warmup), params.method), strict=False)

  # training
  print('\n--- start the training ---')
  train(base_datamgr, datasets, val_loader, val_loader_nd, model, start_epoch, stop_epoch, params)
    # statistics
    # Run n_exp independent experiments. Each one evaluates n_task episode pairs
    # (one base task + one test task) and reports mean accuracy and CI per side.
    print('\n--- get statics ---')
    for i in range(n_exp):
        # Row j holds [base_acc, test_acc] for task j.
        acc_all = np.empty((n_task, 2))
        # Fresh loaders for every experiment, so each experiment starts its own iteration.
        base_data_loader = datamgr.get_data_loader(base_loadfile, aug=False)
        test_data_loader = datamgr.get_data_loader(test_loadfile, aug=False)

        base_data_generator = iter(base_data_loader)
        test_data_generator = iter(test_data_loader)

        task_pbar = tqdm(range(n_task))
        for j in task_pbar:
            # Keep only element 0 of each batch (presumably the images of an
            # (x, y) pair — TODO confirm). next() raises StopIteration if a
            # loader yields fewer than n_task batches.
            base_task = next(base_data_generator)[0]
            test_task = next(test_data_generator)[0]
            # NOTE(review): loop-invariant; could be read once outside the loops.
            n_sub_query = params.n_sub_query
            # Reload weights from modelfile before every task (return value
            # discarded), presumably so adaptation on one task does not carry
            # over to the next — confirm.
            _ = model.resume(modelfile)
            # NOTE(review): dead branch — test_uni is never called; test() always runs.
            if False:
                base_acc, test_acc = test_uni(base_task, test_task, model,
                                              n_iter, n_sub_query, params)
            else:
                base_acc, test_acc = test(base_task, test_task, model, n_iter,
                                          n_sub_query, params)
            acc_all[j] = [base_acc, test_acc]

        # Column-wise statistics: index 0 = base tasks, index 1 = test tasks.
        acc_mean = np.mean(acc_all, axis=0)
        acc_std = np.std(acc_all, axis=0)

        base_acc, test_acc = acc_mean
        # Half-width of a ~95% normal-approximation confidence interval of the mean.
        base_ci, test_ci = 1.96 * acc_std / np.sqrt(n_task)

        print(