def train_cresi(config, paths, fn_mapping, image_suffix, folds_file_loc,
                save_path, log_path, num_channels=3, logger=None, num_workers=0):
    """Train one model per cross-validation fold, then time a data-loading pass.

    Builds a ``ReadingImageProvider`` over ``paths``, splits its image names
    into folds with ``get_csv_folds`` and calls ``train`` once per fold. When
    the module-level ``args.fold`` is set, only that fold index is trained.

    After training, a short data-loading benchmark is run over a hard-coded
    ``/wdata/train`` directory layout and the elapsed time is printed.

    Args:
        config: Training configuration. ``config.num_channels`` selects the
            augmentation pipeline, and ``config`` is forwarded to ``train``.
        paths: Directory mapping passed to ``ReadingImageProvider`` (the
            benchmark below uses ``'images'`` and ``'masks'`` keys; presumably
            the same schema applies here — confirm against the provider).
        fn_mapping: Mapping passed to ``ReadingImageProvider``; presumably
            callables deriving companion file names from an image name.
        image_suffix: Suffix forwarded to ``ReadingImageProvider``.
        folds_file_loc: Fold CSV location read by ``get_csv_folds``.
        save_path: Forwarded to ``train``.
        log_path: Forwarded to ``train``.
        num_channels: Band count forwarded to ``ReadingImageProvider``.
        logger: Optional logger; when falsy, progress is printed instead.
        num_workers: Forwarded to ``train``.
    """
    print('paths', len(paths), paths)
    ds = ReadingImageProvider(RawImageType, paths, fn_mapping,
                              image_suffix=image_suffix, num_channels=num_channels)
    if logger:
        logger.info("len ds: {}".format(len(ds)))
        logger.info("folds_file_loc: {}".format(folds_file_loc))
        logger.info("save_path: {}".format(save_path))
    else:
        print("len ds:", len(ds))
        print("folds_file_loc:", folds_file_loc)
        print("save_path:", save_path)

    folds = get_csv_folds(folds_file_loc, ds.im_names)
    print(folds)
    for fold, (train_idx, val_idx) in enumerate(folds):
        # NOTE(review): ``args`` is a module-level global, not a parameter;
        # the caller's module must define it for this function to run.
        if args.fold is not None and int(args.fold) != fold:
            continue
        if logger:
            logger.info("num workers: {}".format(num_workers))
            logger.info("fold: {}".format(fold))
            logger.info("len(train_idx): {}".format(len(train_idx)))
            logger.info("len(val_idx): {}".format(len(val_idx)))
        else:
            # Fall back to stdout, matching the summary logging above.
            print("num workers:", num_workers)
            print("fold:", fold)
            print("len(train_idx):", len(train_idx))
            print("len(val_idx):", len(val_idx))

        # NOTE(review): the dataset was built with ``num_channels`` but the
        # augmentation is chosen from ``config.num_channels``; confirm that
        # callers always keep the two equal.
        if config.num_channels == 3:
            transforms = get_flips_colors_augmentation()
        else:
            # HSV rescaling is not possible with multiband imagery, so only
            # flip/shift augmentations are used.
            transforms = get_flips_shifts_augmentation()

        train(ds, fold, train_idx, val_idx, config, save_path, log_path,
              num_workers=num_workers, transforms=transforms)

    # --- Data-loading benchmark ---------------------------------------------
    # NOTE(review): this looks like leftover debugging code. It rebinds the
    # ``paths`` / ``fn_mapping`` parameters and ``ds`` to a hard-coded
    # ``/wdata/train`` layout with 8 channels, ignoring the arguments above.
    from itertools import islice

    path_masks_train = '/wdata/train/masks_binned'
    path_images_train = '/wdata/train/psms'

    paths = {
        'masks': path_masks_train,
        'images': path_images_train,
    }

    fn_mapping = {
        # Derive the mask file name from ``name`` (presumably the image file
        # name): swap 'PS-MS' for 'PS-RGB' and force a '.tif' extension.
        'masks': lambda name: os.path.splitext(name)[0].replace('PS-MS', 'PS-RGB') + '.tif'
    }

    ds = ReadingImageProvider(RawImageType, paths, fn_mapping,
                              image_suffix='', num_channels=8)

    # Size the index array from the provider itself rather than a raw
    # directory listing: stray non-image files in the folder would otherwise
    # produce indices past the end of ``ds`` (presumably TrainDataset indexes
    # ``ds`` with these values).
    train_idx = np.arange(len(ds))
    transforms = get_flips_shifts_augmentation()

    train_loader = TrainDataset(ds, train_idx, transforms=transforms)

    b = 15  # items fetched per pass
    gt = time.time()
    for _ in range(2):
        # islice stops after exactly ``b`` items; an enumerate loop with
        # ``if i == b: break`` would fetch ``b + 1`` before breaking.
        for _item in islice(train_loader, b):
            pass

    print('Time to iterate:', time.time() - gt, 'seconds')