Example #1
def main():
    config = get_config()

    dfs, classes = load_data(config)

    for name in dfs.keys():
        dfs[name] = dfs[name].sort_index()

    assert (dfs['test'].index == dfs['prediction'].index).all()

    dfs['test'] = preprocess.update_new_whales(df_train=dfs['train'], df_test=dfs['test'])

    assert (dfs['test'].index == dfs['prediction'].index).all()

    if config.known_only:
        dfs['test'], dfs['prediction'] = preprocess.remove_new_whales(df_test=dfs['test'], df_pred=dfs['prediction'])

        assert (dfs['test'].index == dfs['prediction'].index).all()

    converter = preprocess.Converter(classes)
    pred, actual = converter.to_numpy(dfs['prediction'], dfs['test'])

    report = ReportManager(config.name)

    report.set_info(solution=config.solution, description=config.description, known_only=config.known_only)

    report.add_metric('MAP@5', metrics.mapk(actual, pred, k=5))

    ks = (1, 3, 5)
    tops = metrics.precisionk(actual, pred, topk=ks)

    for k, top in zip(ks, tops):
        report.add_metric(f'Top@{k}', top)

    report.finish(config.output, save=True)
Example #2
    def loss_func(weights):
        ''' scipy minimize will pass the weights as a numpy array '''
        weights /= np.sum(weights)
        # print("weights", weights)
        final_predict = np.zeros_like(train_predicts[0])

        for weight, prediction in zip(weights, train_predicts):
            final_predict += weight * prediction

        score = -mapk(torch.tensor(final_predict), torch.tensor(train_targets))
        print("score", score)
        return score
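
A loss function of this shape is typically handed to scipy.optimize.minimize to search for ensemble weights. A minimal usage sketch, assuming train_predicts and train_targets are defined as above; the starting point, method, and bounds are assumptions, not part of the original example:

import numpy as np
from scipy.optimize import minimize

# Start from equal weights; loss_func itself re-normalizes whatever it receives.
starting_weights = np.ones(len(train_predicts)) / len(train_predicts)
bounds = [(0.0, 1.0)] * len(train_predicts)  # keep each raw weight in [0, 1]

result = minimize(loss_func, starting_weights, method="SLSQP", bounds=bounds)
best_weights = result.x / np.sum(result.x)
print("best weights:", best_weights, "best score:", result.fun)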
Example #3
    def loss_func(weights):
        ''' scipy minimize will pass the weights as a numpy array '''
        print("weights", weights)
        final_predict = np.zeros_like(train_predicts[0])

        for weight, prediction in zip(weights, train_predicts):
            # print("weight", weight, "prediction", prediction)
            final_predict += weight * prediction

        # print("final_predict", final_predict)
        # print("train_targets", train_targets.shape)
        score = mapk(torch.tensor(final_predict), torch.tensor(train_targets))
        print("score", score)
        return score
Example #4
def evaluate(model, data_loader, criterion, mapk_topk):
    model.eval()

    loss_sum_t = zero_item_tensor()
    mapk_sum_t = zero_item_tensor()
    accuracy_top1_sum_t = zero_item_tensor()
    accuracy_top3_sum_t = zero_item_tensor()
    accuracy_top5_sum_t = zero_item_tensor()
    accuracy_top10_sum_t = zero_item_tensor()
    step_count = 0

    with torch.no_grad():
        for batch in data_loader:
            images, categories, categories_one_hot = \
                batch[0].to(device, non_blocking=True), \
                batch[1].to(device, non_blocking=True), \
                batch[2].to(device, non_blocking=True)

            prediction_logits = model(images)
            # if prediction_logits.size(1) == len(class_weights):
            #     criterion.weight = class_weights
            loss = criterion(prediction_logits, get_loss_target(criterion, categories, categories_one_hot))

            num_categories = prediction_logits.size(1)

            loss_sum_t += loss
            mapk_sum_t += mapk(prediction_logits, categories, topk=min(mapk_topk, num_categories))
            accuracy_top1_sum_t += accuracy(prediction_logits, categories, topk=min(1, num_categories))
            accuracy_top3_sum_t += accuracy(prediction_logits, categories, topk=min(3, num_categories))
            accuracy_top5_sum_t += accuracy(prediction_logits, categories, topk=min(5, num_categories))
            accuracy_top10_sum_t += accuracy(prediction_logits, categories, topk=min(10, num_categories))

            step_count += 1

    loss_avg = loss_sum_t.item() / step_count
    mapk_avg = mapk_sum_t.item() / step_count
    accuracy_top1_avg = accuracy_top1_sum_t.item() / step_count
    accuracy_top3_avg = accuracy_top3_sum_t.item() / step_count
    accuracy_top5_avg = accuracy_top5_sum_t.item() / step_count
    accuracy_top10_avg = accuracy_top10_sum_t.item() / step_count

    return loss_avg, mapk_avg, accuracy_top1_avg, accuracy_top3_avg, accuracy_top5_avg, accuracy_top10_avg
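
The zero_item_tensor() helper is not shown in this snippet. A plausible minimal definition, consistent with how the accumulators are used above (an assumption, not the original implementation):

import torch

def zero_item_tensor():
    # Zero-dimensional float accumulator on the same global `device`
    # that the rest of the snippet uses (assumed definition).
    return torch.tensor(0.0, device=device)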
Example #5
    def evaluate_query_set(self, query_set_dict, has_masks, distance_eq):
        gt = query_set_dict["gt"]
        predictions = []

        for file, im in query_set_dict["images"].items():

            mask = None
            if has_masks:
                mask_filename = os.path.join(
                    self.output_folder,
                    file.split("/")[-1].split(".")[0] + ".png",
                )
                im, mask = mask_background(im)
                gt_mask = query_set_dict["masks"][file.replace("jpg", "png")]
                self.calc_mask_metrics(gt_mask[..., 0] / 255, mask / 255)
                if self.opt.save:
                    save_mask(mask_filename, mask)

            fv = self.calc_FV_query(im, mask)
            distances = calculate_distances(self.feature_vector_protoypes, fv,
                                            distance_eq)

            predictions.append(list(distances.argsort()[:10]))

        if self.opt.save:
            save_predictions(
                os.path.join(
                    self.output_folder,
                    "result_{}.pkl".format(int(has_masks) + 1),
                ),
                predictions,
            )

        map_k = mapk(gt, predictions)

        return map_k
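
The calculate_distances helper used here (and again in Example #6 below) is not shown. A minimal sketch, assuming it returns one distance per database feature vector so that argsort() yields a ranking; the distance modes are illustrative:

import numpy as np

def calculate_distances(db_fvs, query_fv, mode="l2"):
    # One distance per row of db_fvs; smaller means closer (assumed sketch).
    if mode == "l2":
        return np.linalg.norm(db_fvs - query_fv, axis=1)
    if mode == "l1":
        return np.abs(db_fvs - query_fv).sum(axis=1)
    raise ValueError("unknown distance mode: {}".format(mode))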
Example #6
def eval_set(loader, gt_correspondences, bbdd_fvs, opt):
    masks_metrics = {"precision": [], "recall": [], "f1": []}
    ious = []
    predictions = []
    set_bboxes = []
    for name, query_image, gt_mask in loader:
        if opt.apply_denoise:
            query_image, Noise_level_before, Noise_level_after, blur_type_last = detect_denoise(
                query_image, opt.blur_type)
        # transform to another color space
        multiple_painting, split_point, bg_mask = detect_paintings(query_image)
        bboxes, bbox_mask = detect_bboxes(query_image)
        res_mask = bg_mask.astype(bool) ^ bbox_mask.astype(
            bool) if loader.detect_bboxes else bg_mask
        if loader.compute_masks:
            if loader.evaluate:
                calc_mask_metrics(masks_metrics, gt_mask / 255, bg_mask)
            if opt.save:
                mask_name = name.split("/")[-1].replace(".jpg", ".png")
                save_mask(
                    os.path.join(opt.output,
                                 loader.root.split("/")[-1], mask_name),
                    res_mask * 255)

        # cropped sets, no need to mask image for retrieval
        if gt_mask is None:
            res_mask = None
        if loader.detect_bboxes:
            set_bboxes.append(bboxes)

        # change colorspace before computing feature vector
        query_image = transform_color(
            query_image, opt.color) if opt.color is not None else query_image
        if multiple_painting and gt_mask is not None:
            im_preds = []
            left_paint = np.zeros_like(res_mask)
            right_paint = np.zeros_like(res_mask)

            left_paint[:, split_point:] = res_mask[:, split_point:]
            right_paint[:, :split_point] = res_mask[:, :split_point]

            res_masks = [left_paint, right_paint]
            for submasks in res_masks:
                query_fv = calc_FV(query_image, opt, submasks).ravel()
                distances = calculate_distances(bbdd_fvs,
                                                query_fv,
                                                mode=opt.dist)
                im_preds.append((distances.argsort()[:10]).tolist())
            predictions.append(im_preds)

        else:
            query_fv = calc_FV(query_image, opt, res_mask).ravel()
            distances = calculate_distances(bbdd_fvs, query_fv, mode=opt.dist)

            predictions.append((distances.argsort()[:10]).tolist())

    if opt.save:
        save_predictions(
            "{}/{}/result.pkl".format(opt.output,
                                      loader.root.split("/")[-1]), predictions)
        save_predictions(
            "{}/{}/text_boxes.pkl".format(opt.output,
                                          loader.root.split("/")[-1]),
            set_bboxes)

    map_k = {
        i: mapk(gt_correspondences, predictions, k=i)
        for i in [10, 3, 1]
    } if loader.evaluate else None
    avg_mask_metrics = averge_masks_metrics(
        masks_metrics) if loader.evaluate else None

    return map_k, avg_mask_metrics
Example #7
grp_agg = grp_agg.groupby(['srch_destination_id',
                           'hotel_cluster']).sum().reset_index()
grp_agg['count'] = grp_agg['count'] - grp_agg['sum']

# subtracting the number of bookings from the total count gives the number of clicks
grp_agg = grp_agg.rename(columns={
    'sum': 'bookings',
    'count': 'clicks'
})

# cross-validation was used to find the best estimate of the click weight, approximately 0.30
grp_agg['relevance'] = grp_agg['bookings'] + click_rel * grp_agg['clicks']
most_rel = grp_agg.groupby(['srch_destination_id']).apply(most_relevant)
most_rel = pd.DataFrame(most_rel).rename(columns={0: 'hotel_cluster'})

test = test.merge(most_rel,
                  how='left',
                  left_on=['srch_destination_id'],
                  right_index=True)

# convert the predictions to the MAPK input format;
# MAPK requires the input to be a list of lists.
preds = []
for index, row in test.iterrows():
    preds.append(row['hotel_cluster_y'])
target = [[l] for l in test["hotel_cluster_x"]]

print "MAPK accuracy is", metrics.mapk(target, preds, k=5)
Example #8
def evaluate_1_vs_all(train,
                      train_lbl,
                      test,
                      test_lbl,
                      n_eval_runs=10,
                      move_to_db=2,
                      k_list=[1, 5, 10]):
    """ Compute accuracy on each class from test set given the training set in multiple runs.
    Input:
    train: 2D numpy float array: array of embeddings for training set, shape = (num_train, len_emb)
    train_lbl: 1D numpy integer array: array of training labels, shape = (num_train,)
    test: 2D numpy float array: array of embeddings for test set, shape = (num_test, len_emb)
    test_lbl: 1D numpy integer array: array of test labels, shape = (num_test,)
    n_eval_runs: integer, number of evaluation runs, default = 10
    move_to_db: integer, number of images to move to a database for each individual, default = 2
    k_list: list of integers, top-k accuracy values to evaluate.

    Returns:
    dict mapping each k in k_list to the mean accuracy (%) over runs,
    dict mapping each k in k_list to the accuracy stdev (%) over runs
    """
    print('Computing top-k accuracy for k=', k_list)
    if isinstance(k_list, int):
        k_list = [k_list]
    # Auxiliary function to flatten a list
    flatten = lambda l: [item for sublist in l for item in sublist]

    # Evaluate accuracy at different k over multiple runs. Report average results.
    acc = {k: [] for k in k_list}
    map_dict = {k: [] for k in k_list}

    for i in range(n_eval_runs):
        neigh_lbl_run = []
        db_emb, db_lbl, query_emb, query_lbl = get_eval_set_one_class(
            train, train_lbl, test, test_lbl, move_to_db=move_to_db)
        print('Number of classes in query set: ', len(db_emb))

        for j in range(len(db_emb)):
            neigh_lbl_un, _, _ = predict_k_neigh(db_emb[j],
                                                 db_lbl[j],
                                                 query_emb[j],
                                                 k=10)
            neigh_lbl_run.append(neigh_lbl_un)

        query_lbl = flatten(query_lbl)
        neigh_lbl_run = flatten(neigh_lbl_run)

        # Calculate accuracy @k in a list of predictions
        for k in k_list:
            acc[k].append(acck(query_lbl, neigh_lbl_run, k=k, verbose=False))
            map_dict[k].append(mapk(query_lbl, neigh_lbl_run, k=k))

    # Report accuracy
    print('Accuracy over {} runs:'.format(n_eval_runs))
    acc_array = np.array([acc[k] for k in k_list], dtype=np.float32)
    acc_runs = np.mean(acc_array, axis=1) * 100
    std_runs = np.std(acc_array, axis=1) * 100
    print('Accuracy: ', acc_runs)
    print('Stdev: ', std_runs)
    for i, k in enumerate(k_list):
        print('ACC@{} %{:.2f} +-{:.2f}'.format(k, acc_runs[i], std_runs[i]))

    # Report Mean average precision at k
    print('MAP over {} runs:'.format(n_eval_runs))
    map_array = np.array([map_dict[k] for k in k_list], dtype=np.float32)
    map_runs = np.mean(map_array, axis=1) * 100
    std_map_runs = np.std(map_array, axis=1) * 100
    for i, k in enumerate(k_list):
        print('MAP@{} %{:.2f} +-{:.2f}'.format(k, map_runs[i],
                                               std_map_runs[i]))

    return dict(zip(k_list, acc_runs)), dict(zip(k_list, std_runs))
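
A hypothetical call, with embedding shapes and label ranges invented purely for illustration (the helpers used inside, such as get_eval_set_one_class and predict_k_neigh, are assumed to be importable from the same project):

import numpy as np

rng = np.random.default_rng(0)
train = rng.normal(size=(200, 128)).astype(np.float32)   # 200 embeddings of length 128
train_lbl = rng.integers(0, 20, size=200)                # labels for 20 individuals
test = rng.normal(size=(60, 128)).astype(np.float32)
test_lbl = rng.integers(0, 20, size=60)

acc_by_k, std_by_k = evaluate_1_vs_all(train, train_lbl, test, test_lbl,
                                       n_eval_runs=3, k_list=[1, 5, 10])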
Example #9
        ndcg_pred = []
        ndcg_orig = []
        ndcg_rnd = []
        gts = []
        res_orig = []
        res_pred = []
        res_rnd = []
        pool = Pool(40)
        for i in range(int(len(playlists_ids) / 1000)):
            results = pool.map(evaluate, range(i * 1000, (i + 1) * 1000))

            for rets_orig, rets_pred, gt in results:
                if len(gt) > 0:
                    res_orig.append(rets_orig)
                    res_pred.append(rets_pred)
                    rets_rnd = [
                        int(tr)
                        for tr in rnd.sample(list(dict_test_ids.keys()), N)
                    ]
                    res_rnd.append(rets_rnd)
                    gts.append(gt)
                    ndcg_pred.append(metrics.ndcg(gt, rets_pred, N))
                    ndcg_orig.append(metrics.ndcg(gt, rets_orig, N))
                    ndcg_rnd.append(metrics.ndcg(gt, rets_rnd, N))
        print("MAP")
        print("PRED MAP@", N, ": ", metrics.mapk(gts, res_pred, N))
        print("ORIG MAP@", N, ": ", metrics.mapk(gts, res_orig, N))
        print("RND MAP@", N, ": ", metrics.mapk(gts, res_rnd, N))
        print("NDCG:")
        print("PRED: ", np.mean(ndcg_pred))
        print("ORIG: ", np.mean(ndcg_orig))
        print("RND: ", np.mean(ndcg_rnd))
Example #10
    if uid in ground_truth:
        overall_scores = [
            score_matrix[uid, lid] if (uid, lid) not in training_tuples else -1
            for lid in all_lids
        ]

        overall_scores = np.array(overall_scores)

        predicted = list(reversed(overall_scores.argsort()))[:100]
        actual = ground_truth[uid]

        # accumulate the metrics at different cut-offs k
        precision_5.append(precisionk(actual, predicted[:5]))
        recall_5.append(recallk(actual, predicted[:5]))
        nDCG_5.append(ndcgk(actual, predicted[:5]))
        MAP_5.append(mapk(actual, predicted[:5], 5))

        precision_10.append(precisionk(actual, predicted[:10]))
        recall_10.append(recallk(actual, predicted[:10]))
        nDCG_10.append(ndcgk(actual, predicted[:10]))
        MAP_10.append(mapk(actual, predicted[:10], 10))

        precision_20.append(precisionk(actual, predicted[:20]))
        recall_20.append(recallk(actual, predicted[:20]))
        nDCG_20.append(ndcgk(actual, predicted[:20]))
        MAP_20.append(mapk(actual, predicted[:20], 20))

        print(cnt, uid, "pre@10:", np.mean(precision_10), "rec@10:",
              np.mean(recall_10))

        rec_list.write('\t'.join(
Example #11
def validate(submission, gt_submission):
    submission['Predicted'] = submission['Predicted'].apply(
        lambda x: list(map(int, x.split(" "))))
    gt_submission['Predicted'] = gt_submission['Predicted'].apply(
        lambda x: list(map(int, x.split(" "))))
    return mapk(gt_submission['Predicted'], submission['Predicted'], k=50)
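
A small usage sketch for validate, assuming both frames follow the usual Kaggle submission layout with a space-separated Predicted column (the rows below are invented for illustration):

import pandas as pd

submission = pd.DataFrame({"Predicted": ["3 1 2", "5 4 6"]})
gt_submission = pd.DataFrame({"Predicted": ["1", "4"]})

print(validate(submission, gt_submission))  # MAP@50 over the two rows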
Example #12
    for cnt, uid in enumerate(all_uids):
        if uid in test_user:
            overall_scores = [all_user_check_in_score_list[uid][lid]
                              if (uid, lid) not in training_tuple else -1
                              for lid in all_lids]

            overall_scores = np.array(overall_scores)

            predicted = list(reversed(overall_scores.argsort()))[:100]
            actual = ground_truth[uid]

            # accumulate the metrics at different cut-offs k
            precision_5.append(precisionk(actual, predicted[:15]))
            recall_5.append(recallk(actual, predicted[:15]))
            nDCG_5.append(ndcgk(actual, predicted[:15]))
            MAP_5.append(mapk(actual, predicted[:15], 15))

            precision_10.append(precisionk(actual, predicted[:25]))
            recall_10.append(recallk(actual, predicted[:25]))
            nDCG_10.append(ndcgk(actual, predicted[:25]))
            MAP_10.append(mapk(actual, predicted[:25], 25))

            precision_20.append(precisionk(actual, predicted[:30]))
            recall_20.append(recallk(actual, predicted[:30]))
            nDCG_20.append(ndcgk(actual, predicted[:30]))
            MAP_20.append(mapk(actual, predicted[:30], 30))

            print(cnt, uid, "pre@10:", np.mean(precision_10), "rec@10:", np.mean(recall_10))

            rec_list.write('\t'.join([
                str(cnt),
Example #13
def main():
    args = argparser.parse_args()
    print("Arguments:")
    for arg in vars(args):
        print("  {}: {}".format(arg, getattr(args, arg)))
    print()

    input_dir = args.input_dir
    output_dir = args.output_dir
    base_model_dir = args.base_model_dir
    image_size = args.image_size
    augment = args.augment
    use_dummy_image = args.use_dummy_image
    use_progressive_image_sizes = args.use_progressive_image_sizes
    progressive_image_size_min = args.progressive_image_size_min
    progressive_image_size_step = args.progressive_image_size_step
    progressive_image_epoch_step = args.progressive_image_epoch_step
    batch_size = args.batch_size
    batch_iterations = args.batch_iterations
    test_size = args.test_size
    train_on_val = args.train_on_val
    fold = args.fold
    train_on_unrecognized = args.train_on_unrecognized
    confusion_set = args.confusion_set
    num_category_shards = args.num_category_shards
    category_shard = args.category_shard
    eval_train_mapk = args.eval_train_mapk
    mapk_topk = args.mapk_topk
    num_shard_preload = args.num_shard_preload
    num_shard_loaders = args.num_shard_loaders
    num_workers = args.num_workers
    pin_memory = args.pin_memory
    epochs_to_train = args.epochs
    lr_scheduler_type = args.lr_scheduler
    lr_patience = args.lr_patience
    lr_min = args.lr_min
    lr_max = args.lr_max
    lr_min_decay = args.lr_min_decay
    lr_max_decay = args.lr_max_decay
    optimizer_type = args.optimizer
    loss_type = args.loss
    bootstraping_loss_ratio = args.bootstraping_loss_ratio
    loss2_type = args.loss2
    loss2_start_sgdr_cycle = args.loss2_start_sgdr_cycle
    model_type = args.model
    patience = args.patience
    sgdr_cycle_epochs = args.sgdr_cycle_epochs
    sgdr_cycle_epochs_mult = args.sgdr_cycle_epochs_mult
    sgdr_cycle_end_prolongation = args.sgdr_cycle_end_prolongation
    sgdr_cycle_end_patience = args.sgdr_cycle_end_patience
    max_sgdr_cycles = args.max_sgdr_cycles

    use_extended_stroke_channels = model_type in ["cnn", "residual_cnn", "fc_cnn", "hc_fc_cnn"]
    print("use_extended_stroke_channels: {}".format(use_extended_stroke_channels), flush=True)

    progressive_image_sizes = list(range(progressive_image_size_min, image_size + 1, progressive_image_size_step))

    train_data_provider = TrainDataProvider(
        input_dir,
        50,
        num_shard_preload=num_shard_preload,
        num_workers=num_shard_loaders,
        test_size=test_size,
        fold=fold,
        train_on_unrecognized=train_on_unrecognized,
        confusion_set=confusion_set,
        num_category_shards=num_category_shards,
        category_shard=category_shard,
        train_on_val=train_on_val)

    train_data = train_data_provider.get_next()

    train_set = TrainDataset(train_data.train_set_df, len(train_data.categories), image_size, use_extended_stroke_channels, augment, use_dummy_image)
    stratified_sampler = StratifiedSampler(train_data.train_set_df["category"], batch_size * batch_iterations)
    train_set_data_loader = \
        DataLoader(train_set, batch_size=batch_size, shuffle=False, sampler=stratified_sampler, num_workers=num_workers,
                   pin_memory=pin_memory)

    val_set = TrainDataset(train_data.val_set_df, len(train_data.categories), image_size, use_extended_stroke_channels, False, use_dummy_image)
    val_set_data_loader = \
        DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

    if base_model_dir:
        for base_file_path in glob.glob("{}/*.pth".format(base_model_dir)):
            shutil.copyfile(base_file_path, "{}/{}".format(output_dir, os.path.basename(base_file_path)))
        model = create_model(type=model_type, input_size=image_size, num_classes=len(train_data.categories)).to(device)
        model.load_state_dict(torch.load("{}/model.pth".format(output_dir), map_location=device))
        optimizer = create_optimizer(optimizer_type, model, lr_max)
        if os.path.isfile("{}/optimizer.pth".format(output_dir)):
            optimizer.load_state_dict(torch.load("{}/optimizer.pth".format(output_dir)))
            adjust_initial_learning_rate(optimizer, lr_max)
            adjust_learning_rate(optimizer, lr_max)
    else:
        model = create_model(type=model_type, input_size=image_size, num_classes=len(train_data.categories)).to(device)
        optimizer = create_optimizer(optimizer_type, model, lr_max)

    torch.save(model.state_dict(), "{}/model.pth".format(output_dir))

    ensemble_model_index = 0
    for model_file_path in glob.glob("{}/model-*.pth".format(output_dir)):
        model_file_name = os.path.basename(model_file_path)
        model_index = int(model_file_name.replace("model-", "").replace(".pth", ""))
        ensemble_model_index = max(ensemble_model_index, model_index + 1)

    if confusion_set is not None:
        shutil.copyfile(
            "/storage/models/quickdraw/seresnext50_confusion/confusion_set_{}.txt".format(confusion_set),
            "{}/confusion_set.txt".format(output_dir))

    epoch_iterations = ceil(len(train_set) / batch_size)

    print("train_set_samples: {}, val_set_samples: {}".format(len(train_set), len(val_set)), flush=True)
    print()

    global_val_mapk_best_avg = float("-inf")
    sgdr_cycle_val_mapk_best_avg = float("-inf")

    lr_scheduler = CosineAnnealingLR(optimizer, T_max=sgdr_cycle_epochs, eta_min=lr_min)

    optim_summary_writer = SummaryWriter(log_dir="{}/logs/optim".format(output_dir))
    train_summary_writer = SummaryWriter(log_dir="{}/logs/train".format(output_dir))
    val_summary_writer = SummaryWriter(log_dir="{}/logs/val".format(output_dir))

    current_sgdr_cycle_epochs = sgdr_cycle_epochs
    sgdr_next_cycle_end_epoch = current_sgdr_cycle_epochs + sgdr_cycle_end_prolongation
    sgdr_iterations = 0
    sgdr_cycle_count = 0
    batch_count = 0
    epoch_of_last_improval = 0

    lr_scheduler_plateau = ReduceLROnPlateau(optimizer, mode="max", min_lr=lr_min, patience=lr_patience, factor=0.8, threshold=1e-4)

    print('{"chart": "best_val_mapk", "axis": "epoch"}')
    print('{"chart": "val_mapk", "axis": "epoch"}')
    print('{"chart": "val_loss", "axis": "epoch"}')
    print('{"chart": "val_accuracy@1", "axis": "epoch"}')
    print('{"chart": "val_accuracy@3", "axis": "epoch"}')
    print('{"chart": "val_accuracy@5", "axis": "epoch"}')
    print('{"chart": "val_accuracy@10", "axis": "epoch"}')
    print('{"chart": "sgdr_cycle", "axis": "epoch"}')
    print('{"chart": "mapk", "axis": "epoch"}')
    print('{"chart": "loss", "axis": "epoch"}')
    print('{"chart": "lr_scaled", "axis": "epoch"}')
    print('{"chart": "mem_used", "axis": "epoch"}')
    print('{"chart": "epoch_time", "axis": "epoch"}')

    train_start_time = time.time()

    criterion = create_criterion(loss_type, len(train_data.categories), bootstraping_loss_ratio)

    if loss_type == "center":
        optimizer_centloss = torch.optim.SGD(criterion.center.parameters(), lr=0.01)

    for epoch in range(epochs_to_train):
        epoch_start_time = time.time()

        print("memory used: {:.2f} GB".format(psutil.virtual_memory().used / 2 ** 30), flush=True)

        if use_progressive_image_sizes:
            next_image_size = \
                progressive_image_sizes[min(epoch // progressive_image_epoch_step, len(progressive_image_sizes) - 1)]

            if train_set.image_size != next_image_size:
                print("changing image size to {}".format(next_image_size), flush=True)
                train_set.image_size = next_image_size
                val_set.image_size = next_image_size

        model.train()

        train_loss_sum_t = zero_item_tensor()
        train_mapk_sum_t = zero_item_tensor()

        epoch_batch_iter_count = 0

        for b, batch in enumerate(train_set_data_loader):
            images, categories, categories_one_hot = \
                batch[0].to(device, non_blocking=True), \
                batch[1].to(device, non_blocking=True), \
                batch[2].to(device, non_blocking=True)

            if lr_scheduler_type == "cosine_annealing":
                lr_scheduler.step(epoch=min(current_sgdr_cycle_epochs, sgdr_iterations / epoch_iterations))

            if b % batch_iterations == 0:
                optimizer.zero_grad()

            prediction_logits = model(images)
            # if prediction_logits.size(1) == len(class_weights):
            #     criterion.weight = class_weights
            loss = criterion(prediction_logits, get_loss_target(criterion, categories, categories_one_hot))
            loss.backward()

            with torch.no_grad():
                train_loss_sum_t += loss
                if eval_train_mapk:
                    train_mapk_sum_t += mapk(prediction_logits, categories,
                                             topk=min(mapk_topk, len(train_data.categories)))

            if (b + 1) % batch_iterations == 0 or (b + 1) == len(train_set_data_loader):
                optimizer.step()
                if loss_type == "center":
                    for param in criterion.center.parameters():
                        param.grad.data *= (1. / 0.5)
                    optimizer_centloss.step()

            sgdr_iterations += 1
            batch_count += 1
            epoch_batch_iter_count += 1

            optim_summary_writer.add_scalar("lr", get_learning_rate(optimizer), batch_count + 1)

        # TODO: recalculate epoch_iterations and maybe other values?
        train_data = train_data_provider.get_next()
        train_set.df = train_data.train_set_df
        val_set.df = train_data.val_set_df
        epoch_iterations = ceil(len(train_set) / batch_size)
        stratified_sampler.class_vector = train_data.train_set_df["category"]

        train_loss_avg = train_loss_sum_t.item() / epoch_batch_iter_count
        train_mapk_avg = train_mapk_sum_t.item() / epoch_batch_iter_count

        val_loss_avg, val_mapk_avg, val_accuracy_top1_avg, val_accuracy_top3_avg, val_accuracy_top5_avg, val_accuracy_top10_avg = \
            evaluate(model, val_set_data_loader, criterion, mapk_topk)

        if lr_scheduler_type == "reduce_on_plateau":
            lr_scheduler_plateau.step(val_mapk_avg)

        model_improved_within_sgdr_cycle = check_model_improved(sgdr_cycle_val_mapk_best_avg, val_mapk_avg)
        if model_improved_within_sgdr_cycle:
            torch.save(model.state_dict(), "{}/model-{}.pth".format(output_dir, ensemble_model_index))
            sgdr_cycle_val_mapk_best_avg = val_mapk_avg

        model_improved = check_model_improved(global_val_mapk_best_avg, val_mapk_avg)
        ckpt_saved = False
        if model_improved:
            torch.save(model.state_dict(), "{}/model.pth".format(output_dir))
            torch.save(optimizer.state_dict(), "{}/optimizer.pth".format(output_dir))
            global_val_mapk_best_avg = val_mapk_avg
            epoch_of_last_improval = epoch
            ckpt_saved = True

        sgdr_reset = False
        if (lr_scheduler_type == "cosine_annealing") and (epoch + 1 >= sgdr_next_cycle_end_epoch) and (epoch - epoch_of_last_improval >= sgdr_cycle_end_patience):
            sgdr_iterations = 0
            current_sgdr_cycle_epochs = int(current_sgdr_cycle_epochs * sgdr_cycle_epochs_mult)
            sgdr_next_cycle_end_epoch = epoch + 1 + current_sgdr_cycle_epochs + sgdr_cycle_end_prolongation

            ensemble_model_index += 1
            sgdr_cycle_val_mapk_best_avg = float("-inf")
            sgdr_cycle_count += 1
            sgdr_reset = True

            new_lr_min = lr_min * (lr_min_decay ** sgdr_cycle_count)
            new_lr_max = lr_max * (lr_max_decay ** sgdr_cycle_count)
            new_lr_max = max(new_lr_max, new_lr_min)

            adjust_learning_rate(optimizer, new_lr_max)
            lr_scheduler = CosineAnnealingLR(optimizer, T_max=current_sgdr_cycle_epochs, eta_min=new_lr_min)
            if loss2_type is not None and sgdr_cycle_count >= loss2_start_sgdr_cycle:
                print("switching to loss type '{}'".format(loss2_type), flush=True)
                criterion = create_criterion(loss2_type, len(train_data.categories), bootstraping_loss_ratio)

        optim_summary_writer.add_scalar("sgdr_cycle", sgdr_cycle_count, epoch + 1)

        train_summary_writer.add_scalar("loss", train_loss_avg, epoch + 1)
        train_summary_writer.add_scalar("mapk", train_mapk_avg, epoch + 1)
        val_summary_writer.add_scalar("loss", val_loss_avg, epoch + 1)
        val_summary_writer.add_scalar("mapk", val_mapk_avg, epoch + 1)

        epoch_end_time = time.time()
        epoch_duration_time = epoch_end_time - epoch_start_time

        print(
            "[%03d/%03d] %ds, lr: %.6f, loss: %.4f, val_loss: %.4f, acc: %.4f, val_acc: %.4f, ckpt: %d, rst: %d" % (
                epoch + 1,
                epochs_to_train,
                epoch_duration_time,
                get_learning_rate(optimizer),
                train_loss_avg,
                val_loss_avg,
                train_mapk_avg,
                val_mapk_avg,
                int(ckpt_saved),
                int(sgdr_reset)))

        print('{"chart": "best_val_mapk", "x": %d, "y": %.4f}' % (epoch + 1, global_val_mapk_best_avg))
        print('{"chart": "val_loss", "x": %d, "y": %.4f}' % (epoch + 1, val_loss_avg))
        print('{"chart": "val_mapk", "x": %d, "y": %.4f}' % (epoch + 1, val_mapk_avg))
        print('{"chart": "val_accuracy@1", "x": %d, "y": %.4f}' % (epoch + 1, val_accuracy_top1_avg))
        print('{"chart": "val_accuracy@3", "x": %d, "y": %.4f}' % (epoch + 1, val_accuracy_top3_avg))
        print('{"chart": "val_accuracy@5", "x": %d, "y": %.4f}' % (epoch + 1, val_accuracy_top5_avg))
        print('{"chart": "val_accuracy@10", "x": %d, "y": %.4f}' % (epoch + 1, val_accuracy_top10_avg))
        print('{"chart": "sgdr_cycle", "x": %d, "y": %d}' % (epoch + 1, sgdr_cycle_count))
        print('{"chart": "loss", "x": %d, "y": %.4f}' % (epoch + 1, train_loss_avg))
        print('{"chart": "mapk", "x": %d, "y": %.4f}' % (epoch + 1, train_mapk_avg))
        print('{"chart": "lr_scaled", "x": %d, "y": %.4f}' % (epoch + 1, 1000 * get_learning_rate(optimizer)))
        print('{"chart": "mem_used", "x": %d, "y": %.2f}' % (epoch + 1, psutil.virtual_memory().used / 2 ** 30))
        print('{"chart": "epoch_time", "x": %d, "y": %d}' % (epoch + 1, epoch_duration_time))

        sys.stdout.flush()

        if (sgdr_reset or lr_scheduler_type == "reduce_on_plateau") and epoch - epoch_of_last_improval >= patience:
            print("early abort due to lack of improval", flush=True)
            break

        if max_sgdr_cycles is not None and sgdr_cycle_count >= max_sgdr_cycles:
            print("early abort due to maximum number of sgdr cycles reached", flush=True)
            break

    optim_summary_writer.close()
    train_summary_writer.close()
    val_summary_writer.close()

    train_end_time = time.time()
    print()
    print("Train time: %s" % str(datetime.timedelta(seconds=train_end_time - train_start_time)), flush=True)

    if False:
        swa_model = create_model(type=model_type, input_size=image_size, num_classes=len(train_data.categories)).to(
            device)
        swa_update_count = 0
        for f in find_sorted_model_files(output_dir):
            print("merging model '{}' into swa model".format(f), flush=True)
            m = create_model(type=model_type, input_size=image_size, num_classes=len(train_data.categories)).to(device)
            m.load_state_dict(torch.load(f, map_location=device))
            swa_update_count += 1
            moving_average(swa_model, m, 1.0 / swa_update_count)
            # bn_update(train_set_data_loader, swa_model)
        torch.save(swa_model.state_dict(), "{}/swa_model.pth".format(output_dir))

    test_data = TestData(input_dir)
    test_set = TestDataset(test_data.df, image_size, use_extended_stroke_channels)
    test_set_data_loader = \
        DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

    model.load_state_dict(torch.load("{}/model.pth".format(output_dir), map_location=device))
    model = Ensemble([model])

    categories = train_data.categories

    submission_df = test_data.df.copy()
    predictions, predicted_words = predict(model, test_set_data_loader, categories, tta=False)
    submission_df["word"] = predicted_words
    np.save("{}/submission_predictions.npy".format(output_dir), np.array(predictions))
    submission_df.to_csv("{}/submission.csv".format(output_dir), columns=["word"])

    submission_df = test_data.df.copy()
    predictions, predicted_words = predict(model, test_set_data_loader, categories, tta=True)
    submission_df["word"] = predicted_words
    np.save("{}/submission_predictions_tta.npy".format(output_dir), np.array(predictions))
    submission_df.to_csv("{}/submission_tta.csv".format(output_dir), columns=["word"])

    val_set_data_loader = \
        DataLoader(val_set, batch_size=64, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

    model = load_ensemble_model(output_dir, 3, val_set_data_loader, criterion, model_type, image_size, len(categories))
    submission_df = test_data.df.copy()
    predictions, predicted_words = predict(model, test_set_data_loader, categories, tta=True)
    submission_df["word"] = predicted_words
    np.save("{}/submission_predictions_ensemble_tta.npy".format(output_dir), np.array(predictions))
    submission_df.to_csv("{}/submission_ensemble_tta.csv".format(output_dir), columns=["word"])

    confusion, _ = calculate_confusion(model, val_set_data_loader, len(categories))
    precisions = np.array([confusion[c, c] for c in range(confusion.shape[0])])
    percentiles = np.percentile(precisions, q=np.linspace(0, 100, 10))

    print()
    print("Category precision percentiles:")
    print(percentiles)

    print()
    print("Categories sorted by precision:")
    print(np.array(categories)[np.argsort(precisions)])