Example #1
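Training entry point for a 28-class multi-label image classifier: it builds the datasets and data loaders, optionally warm-starts from a base model directory, trains with SGDR-style cosine annealing warm restarts (saving a snapshot model per cycle), tracks the best validation F1 score, then tunes a prediction threshold on the validation set and writes a submission CSV.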
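# Imports inferred from the code below; they are not part of the original
# snippet, and the SummaryWriter import is an assumption (tensorboardX here,
# torch.utils.tensorboard in newer PyTorch). Project-specific helpers such as
# argparser, device, log, log_args, TrainData, TrainDataset, TestData,
# TestDataset, create_model, create_optimizer, create_criterion, evaluate,
# predict, zero_item_tensor, f1_score_from_probs and CLASS_WEIGHTS_TENSOR are
# assumed to be defined elsewhere in the module.
import datetime
import glob
import os
import shutil
import time
from math import ceil

import numpy as np
import psutil
import torch
from tensorboardX import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau, StepLR
from torch.utils.data import DataLoader, WeightedRandomSampler
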
def main():
    args = argparser.parse_args()
    log_args(args)

    input_dir = args.input_dir
    output_dir = args.output_dir
    base_model_dir = args.base_model_dir
    image_size = args.image_size
    crop_images = args.crop_images
    augment = args.augment
    use_progressive_image_sizes = args.use_progressive_image_sizes
    progressive_image_size_min = args.progressive_image_size_min
    progressive_image_size_step = args.progressive_image_size_step
    progressive_image_epoch_step = args.progressive_image_epoch_step
    batch_size = args.batch_size
    batch_iterations = args.batch_iterations
    num_workers = args.num_workers
    pin_memory = args.pin_memory
    epochs_to_train = args.epochs
    lr_scheduler_type = args.lr_scheduler
    lr_patience = args.lr_patience
    lr_min = args.lr_min
    lr_max = args.lr_max
    lr_min_decay = args.lr_min_decay
    lr_max_decay = args.lr_max_decay
    optimizer_type = args.optimizer
    loss_type = args.loss
    focal_loss_gamma = args.focal_loss_gamma
    use_class_weights = args.use_class_weights
    use_weighted_sampling = args.use_weighted_sampling
    model_type = args.model
    patience = args.patience
    sgdr_cycle_epochs = args.sgdr_cycle_epochs
    sgdr_cycle_epochs_mult = args.sgdr_cycle_epochs_mult
    sgdr_cycle_end_prolongation = args.sgdr_cycle_end_prolongation
    sgdr_cycle_end_patience = args.sgdr_cycle_end_patience
    max_sgdr_cycles = args.max_sgdr_cycles

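    # With Adam, lr_scheduler_type is set to a value that matches none of the
    # scheduler branches further down, so the manual LR schedules are skipped.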
    if optimizer_type == "adam":
        lr_scheduler_type = "adam"

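    # Ladder of image sizes used for progressive resizing during training.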
    progressive_image_sizes = list(
        range(progressive_image_size_min, image_size + 1,
              progressive_image_size_step))

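    # Training data, dataset and an optional class-balancing weighted sampler.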
    train_data = TrainData(input_dir)

    train_set = TrainDataset(train_data.train_set_df, input_dir, 28,
                             image_size, crop_images, augment)

    balance_weights, balance_class_weights = calculate_balance_weights(
        train_data.df, train_data.train_set_df, 28)
    train_set_sampler = WeightedRandomSampler(balance_weights,
                                              len(balance_weights))

    train_set_data_loader = DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=not use_weighted_sampling,
        sampler=train_set_sampler if use_weighted_sampling else None,
        num_workers=num_workers,
        pin_memory=pin_memory)

    val_set = TrainDataset(train_data.val_set_df, input_dir, 28, image_size,
                           crop_images, False)
    val_set_data_loader = \
        DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

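    # Optionally warm-start model and optimizer weights from a previous run.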
    if base_model_dir:
        for base_file_path in glob.glob("{}/*.pth".format(base_model_dir)):
            shutil.copyfile(
                base_file_path,
                "{}/{}".format(output_dir, os.path.basename(base_file_path)))
        model = create_model(type=model_type, num_classes=28).to(device)
        model.load_state_dict(
            torch.load("{}/model.pth".format(output_dir), map_location=device))
        optimizer = create_optimizer(optimizer_type, model, lr_max)
        if os.path.isfile("{}/optimizer.pth".format(output_dir)):
            try:
                optimizer.load_state_dict(
                    torch.load("{}/optimizer.pth".format(output_dir)))
                adjust_initial_learning_rate(optimizer, lr_max)
                adjust_learning_rate(optimizer, lr_max)
            except Exception:
                log("Failed to load the optimizer weights")
    else:
        model = create_model(type=model_type, num_classes=28).to(device)
        optimizer = create_optimizer(optimizer_type, model, lr_max)

    torch.save(model.state_dict(), "{}/model.pth".format(output_dir))

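    # Continue the numbering of ensemble snapshot files (model-*.pth) from any existing ones.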
    ensemble_model_index = 0
    for model_file_path in glob.glob("{}/model-*.pth".format(output_dir)):
        model_file_name = os.path.basename(model_file_path)
        model_index = int(
            model_file_name.replace("model-", "").replace(".pth", ""))
        ensemble_model_index = max(ensemble_model_index, model_index + 1)

    epoch_iterations = ceil(len(train_set) / batch_size)

    log("train_set_samples: {}, val_set_samples: {}".format(
        len(train_set), len(val_set)))
    log()

    global_val_score_best_avg = float("-inf")
    sgdr_cycle_val_score_best_avg = float("-inf")

    lr_scheduler = CosineAnnealingLR(optimizer,
                                     T_max=sgdr_cycle_epochs,
                                     eta_min=lr_min)

    optim_summary_writer = SummaryWriter(
        log_dir="{}/logs/optim".format(output_dir))
    train_summary_writer = SummaryWriter(
        log_dir="{}/logs/train".format(output_dir))
    val_summary_writer = SummaryWriter(
        log_dir="{}/logs/val".format(output_dir))

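    # Bookkeeping for SGDR (cosine annealing with warm restarts).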
    current_sgdr_cycle_epochs = sgdr_cycle_epochs
    sgdr_next_cycle_end_epoch = current_sgdr_cycle_epochs + sgdr_cycle_end_prolongation
    sgdr_iterations = 0
    sgdr_cycle_count = 0
    batch_count = 0
    epoch_of_last_improval = 0

    lr_scheduler_plateau = \
        ReduceLROnPlateau(optimizer, mode="max", min_lr=lr_min, patience=lr_patience, factor=0.5, threshold=1e-4)

    lr_scheduler_step = StepLR(optimizer, step_size=10, gamma=0.1)

    log('{"chart": "best_val_score", "axis": "epoch"}')
    log('{"chart": "val_score", "axis": "epoch"}')
    log('{"chart": "val_loss", "axis": "epoch"}')
    log('{"chart": "sgdr_cycle", "axis": "epoch"}')
    log('{"chart": "score", "axis": "epoch"}')
    log('{"chart": "loss", "axis": "epoch"}')
    log('{"chart": "lr_scaled", "axis": "epoch"}')
    log('{"chart": "mem_used", "axis": "epoch"}')
    log('{"chart": "epoch_time", "axis": "epoch"}')

    train_start_time = time.time()

    loss_weight = CLASS_WEIGHTS_TENSOR if use_class_weights else None
    criterion = create_criterion(loss_type, loss_weight, focal_loss_gamma)

    for epoch in range(epochs_to_train):
        epoch_start_time = time.time()

        log("memory used: {:.2f} GB".format(psutil.virtual_memory().used /
                                            2**30))

        if use_progressive_image_sizes:
            next_image_size = \
                progressive_image_sizes[min(epoch // progressive_image_epoch_step, len(progressive_image_sizes) - 1)]

            if train_set.image_size != next_image_size:
                log("changing image size to {}".format(next_image_size))
                train_set.image_size = next_image_size
                val_set.image_size = next_image_size

        model.train()

        train_loss_sum_t = zero_item_tensor()

        epoch_batch_iter_count = 0

        if lr_scheduler_type == "lr_finder":
            new_lr = lr_max * 0.5**(sgdr_cycle_epochs - min(
                sgdr_cycle_epochs, sgdr_iterations / epoch_iterations))
            adjust_learning_rate(optimizer, new_lr)

        all_predictions = []
        all_targets = []
        for b, batch in enumerate(train_set_data_loader):
            images, categories = \
                batch[0].to(device, non_blocking=True), \
                batch[1].to(device, non_blocking=True)

            if lr_scheduler_type == "cosine_annealing":
                lr_scheduler.step(
                    epoch=min(current_sgdr_cycle_epochs, sgdr_iterations /
                              epoch_iterations))

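            # Accumulate gradients over batch_iterations mini-batches before each optimizer step.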
            if b % batch_iterations == 0:
                optimizer.zero_grad()

            prediction_logits = model(images)
            if use_class_weights:
                criterion.weight = CLASS_WEIGHTS_TENSOR
            loss = criterion(prediction_logits, categories)
            loss.backward()

            with torch.no_grad():
                train_loss_sum_t += loss
                all_predictions.extend(
                    torch.sigmoid(prediction_logits).cpu().data.numpy())
                all_targets.extend(categories.cpu().data.numpy())

            if (b + 1) % batch_iterations == 0 or (
                    b + 1) == len(train_set_data_loader):
                optimizer.step()

            sgdr_iterations += 1
            batch_count += 1
            epoch_batch_iter_count += 1

            optim_summary_writer.add_scalar("lr", get_learning_rate(optimizer),
                                            batch_count + 1)

        train_loss_avg = train_loss_sum_t.item() / epoch_batch_iter_count
        train_score_avg = f1_score_from_probs(torch.tensor(all_predictions),
                                              torch.tensor(all_targets))

        val_loss_avg, val_score_avg = evaluate(model, val_set_data_loader,
                                               criterion)

        if lr_scheduler_type == "reduce_on_plateau":
            lr_scheduler_plateau.step(val_score_avg)
        elif lr_scheduler_type == "step":
            lr_scheduler_step.step(epoch)

        model_improved_within_sgdr_cycle = check_model_improved(
            sgdr_cycle_val_score_best_avg, val_score_avg)
        if model_improved_within_sgdr_cycle:
            torch.save(
                model.state_dict(),
                "{}/model-{}.pth".format(output_dir, ensemble_model_index))
            sgdr_cycle_val_score_best_avg = val_score_avg

        model_improved = check_model_improved(global_val_score_best_avg,
                                              val_score_avg)
        ckpt_saved = False
        if model_improved:
            torch.save(model.state_dict(), "{}/model.pth".format(output_dir))
            torch.save(optimizer.state_dict(),
                       "{}/optimizer.pth".format(output_dir))
            np.save("{}/train_predictions.npy".format(output_dir),
                    all_predictions)
            np.save("{}/train_targets.npy".format(output_dir), all_targets)
            global_val_score_best_avg = val_score_avg
            epoch_of_last_improval = epoch
            ckpt_saved = True

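        # End of an SGDR cycle without recent improvement: start a new cycle with
        # decayed LR bounds, a fresh cosine schedule and a new ensemble snapshot.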
        sgdr_reset = False
        if (lr_scheduler_type == "cosine_annealing") \
                and (epoch + 1 >= sgdr_next_cycle_end_epoch) \
                and (epoch - epoch_of_last_improval >= sgdr_cycle_end_patience):
            sgdr_iterations = 0
            current_sgdr_cycle_epochs = int(current_sgdr_cycle_epochs *
                                            sgdr_cycle_epochs_mult)
            sgdr_next_cycle_end_epoch = epoch + 1 + current_sgdr_cycle_epochs + sgdr_cycle_end_prolongation

            ensemble_model_index += 1
            sgdr_cycle_val_score_best_avg = float("-inf")
            sgdr_cycle_count += 1
            sgdr_reset = True

            new_lr_min = lr_min * (lr_min_decay**sgdr_cycle_count)
            new_lr_max = lr_max * (lr_max_decay**sgdr_cycle_count)
            new_lr_max = max(new_lr_max, new_lr_min)

            adjust_learning_rate(optimizer, new_lr_max)
            lr_scheduler = CosineAnnealingLR(optimizer,
                                             T_max=current_sgdr_cycle_epochs,
                                             eta_min=new_lr_min)

        optim_summary_writer.add_scalar("sgdr_cycle", sgdr_cycle_count,
                                        epoch + 1)

        train_summary_writer.add_scalar("loss", train_loss_avg, epoch + 1)
        train_summary_writer.add_scalar("score", train_score_avg, epoch + 1)
        val_summary_writer.add_scalar("loss", val_loss_avg, epoch + 1)
        val_summary_writer.add_scalar("score", val_score_avg, epoch + 1)

        epoch_end_time = time.time()
        epoch_duration_time = epoch_end_time - epoch_start_time

        log("[%03d/%03d] %ds, lr: %.6f, loss: %.4f, val_loss: %.4f, score: %.4f, val_score: %.4f, ckpt: %d, rst: %d"
            %
            (epoch + 1, epochs_to_train, epoch_duration_time,
             get_learning_rate(optimizer), train_loss_avg, val_loss_avg,
             train_score_avg, val_score_avg, int(ckpt_saved), int(sgdr_reset)))

        log('{"chart": "best_val_score", "x": %d, "y": %.4f}' %
            (epoch + 1, global_val_score_best_avg))
        log('{"chart": "val_loss", "x": %d, "y": %.4f}' %
            (epoch + 1, val_loss_avg))
        log('{"chart": "val_score", "x": %d, "y": %.4f}' %
            (epoch + 1, val_score_avg))
        log('{"chart": "sgdr_cycle", "x": %d, "y": %d}' %
            (epoch + 1, sgdr_cycle_count))
        log('{"chart": "loss", "x": %d, "y": %.4f}' %
            (epoch + 1, train_loss_avg))
        log('{"chart": "score", "x": %d, "y": %.4f}' %
            (epoch + 1, train_score_avg))
        log('{"chart": "lr_scaled", "x": %d, "y": %.4f}' %
            (epoch + 1, 1000 * get_learning_rate(optimizer)))
        log('{"chart": "mem_used", "x": %d, "y": %.2f}' %
            (epoch + 1, psutil.virtual_memory().used / 2**30))
        log('{"chart": "epoch_time", "x": %d, "y": %d}' %
            (epoch + 1, epoch_duration_time))

        if (sgdr_reset or lr_scheduler_type in ("reduce_on_plateau", "step")) \
                and epoch - epoch_of_last_improval >= patience:
            log("early abort due to lack of improval")
            break

        if max_sgdr_cycles is not None and sgdr_cycle_count >= max_sgdr_cycles:
            log("early abort due to maximum number of sgdr cycles reached")
            break

    optim_summary_writer.close()
    train_summary_writer.close()
    val_summary_writer.close()

    train_end_time = time.time()
    log()
    log("Train time: %s" %
        str(datetime.timedelta(seconds=train_end_time - train_start_time)))

    model.load_state_dict(
        torch.load("{}/model.pth".format(output_dir), map_location=device))

    val_predictions, val_targets = predict(model, val_set_data_loader)
    np.save("{}/val_predictions.npy".format(output_dir), val_predictions)
    np.save("{}/val_targets.npy".format(output_dir), val_targets)

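    # Tune the binarization threshold on the validation predictions.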
    best_threshold, best_threshold_score, all_threshold_scores = calculate_best_threshold(
        val_predictions, val_targets)
    log("All threshold scores: {}".format(all_threshold_scores))
    log("Best threshold / score: {} / {}".format(best_threshold,
                                                 best_threshold_score))

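    # Predict on the test set and write the submission CSV using the tuned threshold.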
    test_data = TestData(input_dir)
    test_set = TestDataset(test_data.test_set_df, input_dir, image_size,
                           crop_images)
    test_set_data_loader = \
        DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

    test_predictions, _ = predict(model, test_set_data_loader)
    np.save("{}/test_predictions.npy".format(output_dir), test_predictions)

    predicted_categories = calculate_categories_from_predictions(
        test_predictions, threshold=best_threshold)

    submission_df = test_data.test_set_df.copy()
    submission_df["Predicted"] = [
        " ".join(map(str, pc)) for pc in predicted_categories
    ]
    submission_df.to_csv("{}/submission.csv".format(output_dir))
Example #2
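A variant of the same training loop for a multi-class classifier whose data arrives in shards from a TrainDataProvider: it trains with SGDR warm restarts, supports center loss and an optional secondary loss, evaluates MAP@k and top-k accuracy, and writes plain, TTA and ensemble-TTA submissions plus a per-category precision analysis.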
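# Imports inferred from the code below; they are not part of the original
# snippet (SummaryWriter from tensorboardX is an assumption). Project-specific
# helpers such as argparser, device, TrainDataProvider, TrainDataset,
# StratifiedSampler, TestData, TestDataset, create_model, create_optimizer,
# create_criterion, Ensemble, load_ensemble_model, predict, evaluate, mapk and
# calculate_confusion are assumed to be defined elsewhere in the module.
import datetime
import glob
import os
import shutil
import sys
import time
from math import ceil

import numpy as np
import psutil
import torch
from tensorboardX import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
from torch.utils.data import DataLoader
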
def main():
    args = argparser.parse_args()
    print("Arguments:")
    for arg in vars(args):
        print("  {}: {}".format(arg, getattr(args, arg)))
    print()

    input_dir = args.input_dir
    output_dir = args.output_dir
    base_model_dir = args.base_model_dir
    image_size = args.image_size
    augment = args.augment
    use_dummy_image = args.use_dummy_image
    use_progressive_image_sizes = args.use_progressive_image_sizes
    progressive_image_size_min = args.progressive_image_size_min
    progressive_image_size_step = args.progressive_image_size_step
    progressive_image_epoch_step = args.progressive_image_epoch_step
    batch_size = args.batch_size
    batch_iterations = args.batch_iterations
    test_size = args.test_size
    train_on_val = args.train_on_val
    fold = args.fold
    train_on_unrecognized = args.train_on_unrecognized
    confusion_set = args.confusion_set
    num_category_shards = args.num_category_shards
    category_shard = args.category_shard
    eval_train_mapk = args.eval_train_mapk
    mapk_topk = args.mapk_topk
    num_shard_preload = args.num_shard_preload
    num_shard_loaders = args.num_shard_loaders
    num_workers = args.num_workers
    pin_memory = args.pin_memory
    epochs_to_train = args.epochs
    lr_scheduler_type = args.lr_scheduler
    lr_patience = args.lr_patience
    lr_min = args.lr_min
    lr_max = args.lr_max
    lr_min_decay = args.lr_min_decay
    lr_max_decay = args.lr_max_decay
    optimizer_type = args.optimizer
    loss_type = args.loss
    bootstraping_loss_ratio = args.bootstraping_loss_ratio
    loss2_type = args.loss2
    loss2_start_sgdr_cycle = args.loss2_start_sgdr_cycle
    model_type = args.model
    patience = args.patience
    sgdr_cycle_epochs = args.sgdr_cycle_epochs
    sgdr_cycle_epochs_mult = args.sgdr_cycle_epochs_mult
    sgdr_cycle_end_prolongation = args.sgdr_cycle_end_prolongation
    sgdr_cycle_end_patience = args.sgdr_cycle_end_patience
    max_sgdr_cycles = args.max_sgdr_cycles

    use_extended_stroke_channels = model_type in ["cnn", "residual_cnn", "fc_cnn", "hc_fc_cnn"]
    print("use_extended_stroke_channels: {}".format(use_extended_stroke_channels), flush=True)

    progressive_image_sizes = list(range(progressive_image_size_min, image_size + 1, progressive_image_size_step))

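    # Data provider that serves the training set in shards, preloading upcoming
    # shards with background workers.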
    train_data_provider = TrainDataProvider(
        input_dir,
        50,
        num_shard_preload=num_shard_preload,
        num_workers=num_shard_loaders,
        test_size=test_size,
        fold=fold,
        train_on_unrecognized=train_on_unrecognized,
        confusion_set=confusion_set,
        num_category_shards=num_category_shards,
        category_shard=category_shard,
        train_on_val=train_on_val)

    train_data = train_data_provider.get_next()

    train_set = TrainDataset(train_data.train_set_df, len(train_data.categories), image_size, use_extended_stroke_channels, augment, use_dummy_image)
    stratified_sampler = StratifiedSampler(train_data.train_set_df["category"], batch_size * batch_iterations)
    train_set_data_loader = \
        DataLoader(train_set, batch_size=batch_size, shuffle=False, sampler=stratified_sampler, num_workers=num_workers,
                   pin_memory=pin_memory)

    val_set = TrainDataset(train_data.val_set_df, len(train_data.categories), image_size, use_extended_stroke_channels, False, use_dummy_image)
    val_set_data_loader = \
        DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

    if base_model_dir:
        for base_file_path in glob.glob("{}/*.pth".format(base_model_dir)):
            shutil.copyfile(base_file_path, "{}/{}".format(output_dir, os.path.basename(base_file_path)))
        model = create_model(type=model_type, input_size=image_size, num_classes=len(train_data.categories)).to(device)
        model.load_state_dict(torch.load("{}/model.pth".format(output_dir), map_location=device))
        optimizer = create_optimizer(optimizer_type, model, lr_max)
        if os.path.isfile("{}/optimizer.pth".format(output_dir)):
            optimizer.load_state_dict(torch.load("{}/optimizer.pth".format(output_dir)))
            adjust_initial_learning_rate(optimizer, lr_max)
            adjust_learning_rate(optimizer, lr_max)
    else:
        model = create_model(type=model_type, input_size=image_size, num_classes=len(train_data.categories)).to(device)
        optimizer = create_optimizer(optimizer_type, model, lr_max)

    torch.save(model.state_dict(), "{}/model.pth".format(output_dir))

    ensemble_model_index = 0
    for model_file_path in glob.glob("{}/model-*.pth".format(output_dir)):
        model_file_name = os.path.basename(model_file_path)
        model_index = int(model_file_name.replace("model-", "").replace(".pth", ""))
        ensemble_model_index = max(ensemble_model_index, model_index + 1)

    if confusion_set is not None:
        shutil.copyfile(
            "/storage/models/quickdraw/seresnext50_confusion/confusion_set_{}.txt".format(confusion_set),
            "{}/confusion_set.txt".format(output_dir))

    epoch_iterations = ceil(len(train_set) / batch_size)

    print("train_set_samples: {}, val_set_samples: {}".format(len(train_set), len(val_set)), flush=True)
    print()

    global_val_mapk_best_avg = float("-inf")
    sgdr_cycle_val_mapk_best_avg = float("-inf")

    lr_scheduler = CosineAnnealingLR(optimizer, T_max=sgdr_cycle_epochs, eta_min=lr_min)

    optim_summary_writer = SummaryWriter(log_dir="{}/logs/optim".format(output_dir))
    train_summary_writer = SummaryWriter(log_dir="{}/logs/train".format(output_dir))
    val_summary_writer = SummaryWriter(log_dir="{}/logs/val".format(output_dir))

    current_sgdr_cycle_epochs = sgdr_cycle_epochs
    sgdr_next_cycle_end_epoch = current_sgdr_cycle_epochs + sgdr_cycle_end_prolongation
    sgdr_iterations = 0
    sgdr_cycle_count = 0
    batch_count = 0
    epoch_of_last_improval = 0

    lr_scheduler_plateau = ReduceLROnPlateau(optimizer, mode="max", min_lr=lr_min, patience=lr_patience, factor=0.8, threshold=1e-4)

    print('{"chart": "best_val_mapk", "axis": "epoch"}')
    print('{"chart": "val_mapk", "axis": "epoch"}')
    print('{"chart": "val_loss", "axis": "epoch"}')
    print('{"chart": "val_accuracy@1", "axis": "epoch"}')
    print('{"chart": "val_accuracy@3", "axis": "epoch"}')
    print('{"chart": "val_accuracy@5", "axis": "epoch"}')
    print('{"chart": "val_accuracy@10", "axis": "epoch"}')
    print('{"chart": "sgdr_cycle", "axis": "epoch"}')
    print('{"chart": "mapk", "axis": "epoch"}')
    print('{"chart": "loss", "axis": "epoch"}')
    print('{"chart": "lr_scaled", "axis": "epoch"}')
    print('{"chart": "mem_used", "axis": "epoch"}')
    print('{"chart": "epoch_time", "axis": "epoch"}')

    train_start_time = time.time()

    criterion = create_criterion(loss_type, len(train_data.categories), bootstraping_loss_ratio)

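    # Center loss keeps its own learnable class centers, trained with a separate SGD optimizer.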
    if loss_type == "center":
        optimizer_centloss = torch.optim.SGD(criterion.center.parameters(), lr=0.01)

    for epoch in range(epochs_to_train):
        epoch_start_time = time.time()

        print("memory used: {:.2f} GB".format(psutil.virtual_memory().used / 2 ** 30), flush=True)

        if use_progressive_image_sizes:
            next_image_size = \
                progressive_image_sizes[min(epoch // progressive_image_epoch_step, len(progressive_image_sizes) - 1)]

            if train_set.image_size != next_image_size:
                print("changing image size to {}".format(next_image_size), flush=True)
                train_set.image_size = next_image_size
                val_set.image_size = next_image_size

        model.train()

        train_loss_sum_t = zero_item_tensor()
        train_mapk_sum_t = zero_item_tensor()

        epoch_batch_iter_count = 0

        for b, batch in enumerate(train_set_data_loader):
            images, categories, categories_one_hot = \
                batch[0].to(device, non_blocking=True), \
                batch[1].to(device, non_blocking=True), \
                batch[2].to(device, non_blocking=True)

            if lr_scheduler_type == "cosine_annealing":
                lr_scheduler.step(epoch=min(current_sgdr_cycle_epochs, sgdr_iterations / epoch_iterations))

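            # Accumulate gradients over batch_iterations mini-batches before each optimizer step.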
            if b % batch_iterations == 0:
                optimizer.zero_grad()

            prediction_logits = model(images)
            # if prediction_logits.size(1) == len(class_weights):
            #     criterion.weight = class_weights
            loss = criterion(prediction_logits, get_loss_target(criterion, categories, categories_one_hot))
            loss.backward()

            with torch.no_grad():
                train_loss_sum_t += loss
                if eval_train_mapk:
                    train_mapk_sum_t += mapk(prediction_logits, categories,
                                             topk=min(mapk_topk, len(train_data.categories)))

            if (b + 1) % batch_iterations == 0 or (b + 1) == len(train_set_data_loader):
                optimizer.step()
                if loss_type == "center":
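                    # Rescale the center gradients to cancel the center-loss weighting
                    # (assumed to be 0.5 here) before updating the centers.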
                    for param in criterion.center.parameters():
                        param.grad.data *= (1. / 0.5)
                    optimizer_centloss.step()

            sgdr_iterations += 1
            batch_count += 1
            epoch_batch_iter_count += 1

            optim_summary_writer.add_scalar("lr", get_learning_rate(optimizer), batch_count + 1)

        # TODO: recalculate epoch_iterations and maybe other values?
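        # Swap in the next data shard for the coming epoch and update the sampler accordingly.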
        train_data = train_data_provider.get_next()
        train_set.df = train_data.train_set_df
        val_set.df = train_data.val_set_df
        epoch_iterations = ceil(len(train_set) / batch_size)
        stratified_sampler.class_vector = train_data.train_set_df["category"]

        train_loss_avg = train_loss_sum_t.item() / epoch_batch_iter_count
        train_mapk_avg = train_mapk_sum_t.item() / epoch_batch_iter_count

        val_loss_avg, val_mapk_avg, val_accuracy_top1_avg, val_accuracy_top3_avg, val_accuracy_top5_avg, val_accuracy_top10_avg = \
            evaluate(model, val_set_data_loader, criterion, mapk_topk)

        if lr_scheduler_type == "reduce_on_plateau":
            lr_scheduler_plateau.step(val_mapk_avg)

        model_improved_within_sgdr_cycle = check_model_improved(sgdr_cycle_val_mapk_best_avg, val_mapk_avg)
        if model_improved_within_sgdr_cycle:
            torch.save(model.state_dict(), "{}/model-{}.pth".format(output_dir, ensemble_model_index))
            sgdr_cycle_val_mapk_best_avg = val_mapk_avg

        model_improved = check_model_improved(global_val_mapk_best_avg, val_mapk_avg)
        ckpt_saved = False
        if model_improved:
            torch.save(model.state_dict(), "{}/model.pth".format(output_dir))
            torch.save(optimizer.state_dict(), "{}/optimizer.pth".format(output_dir))
            global_val_mapk_best_avg = val_mapk_avg
            epoch_of_last_improval = epoch
            ckpt_saved = True

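        # End of an SGDR cycle without recent improvement: start a new cycle with
        # decayed LR bounds and a fresh cosine schedule.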
        sgdr_reset = False
        if (lr_scheduler_type == "cosine_annealing") and (epoch + 1 >= sgdr_next_cycle_end_epoch) and (epoch - epoch_of_last_improval >= sgdr_cycle_end_patience):
            sgdr_iterations = 0
            current_sgdr_cycle_epochs = int(current_sgdr_cycle_epochs * sgdr_cycle_epochs_mult)
            sgdr_next_cycle_end_epoch = epoch + 1 + current_sgdr_cycle_epochs + sgdr_cycle_end_prolongation

            ensemble_model_index += 1
            sgdr_cycle_val_mapk_best_avg = float("-inf")
            sgdr_cycle_count += 1
            sgdr_reset = True

            new_lr_min = lr_min * (lr_min_decay ** sgdr_cycle_count)
            new_lr_max = lr_max * (lr_max_decay ** sgdr_cycle_count)
            new_lr_max = max(new_lr_max, new_lr_min)

            adjust_learning_rate(optimizer, new_lr_max)
            lr_scheduler = CosineAnnealingLR(optimizer, T_max=current_sgdr_cycle_epochs, eta_min=new_lr_min)
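            # Optionally switch to the secondary loss once enough SGDR cycles have finished.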
            if loss2_type is not None and sgdr_cycle_count >= loss2_start_sgdr_cycle:
                print("switching to loss type '{}'".format(loss2_type), flush=True)
                criterion = create_criterion(loss2_type, len(train_data.categories), bootstraping_loss_ratio)

        optim_summary_writer.add_scalar("sgdr_cycle", sgdr_cycle_count, epoch + 1)

        train_summary_writer.add_scalar("loss", train_loss_avg, epoch + 1)
        train_summary_writer.add_scalar("mapk", train_mapk_avg, epoch + 1)
        val_summary_writer.add_scalar("loss", val_loss_avg, epoch + 1)
        val_summary_writer.add_scalar("mapk", val_mapk_avg, epoch + 1)

        epoch_end_time = time.time()
        epoch_duration_time = epoch_end_time - epoch_start_time

        print(
            "[%03d/%03d] %ds, lr: %.6f, loss: %.4f, val_loss: %.4f, acc: %.4f, val_acc: %.4f, ckpt: %d, rst: %d" % (
                epoch + 1,
                epochs_to_train,
                epoch_duration_time,
                get_learning_rate(optimizer),
                train_loss_avg,
                val_loss_avg,
                train_mapk_avg,
                val_mapk_avg,
                int(ckpt_saved),
                int(sgdr_reset)))

        print('{"chart": "best_val_mapk", "x": %d, "y": %.4f}' % (epoch + 1, global_val_mapk_best_avg))
        print('{"chart": "val_loss", "x": %d, "y": %.4f}' % (epoch + 1, val_loss_avg))
        print('{"chart": "val_mapk", "x": %d, "y": %.4f}' % (epoch + 1, val_mapk_avg))
        print('{"chart": "val_accuracy@1", "x": %d, "y": %.4f}' % (epoch + 1, val_accuracy_top1_avg))
        print('{"chart": "val_accuracy@3", "x": %d, "y": %.4f}' % (epoch + 1, val_accuracy_top3_avg))
        print('{"chart": "val_accuracy@5", "x": %d, "y": %.4f}' % (epoch + 1, val_accuracy_top5_avg))
        print('{"chart": "val_accuracy@10", "x": %d, "y": %.4f}' % (epoch + 1, val_accuracy_top10_avg))
        print('{"chart": "sgdr_cycle", "x": %d, "y": %d}' % (epoch + 1, sgdr_cycle_count))
        print('{"chart": "loss", "x": %d, "y": %.4f}' % (epoch + 1, train_loss_avg))
        print('{"chart": "mapk", "x": %d, "y": %.4f}' % (epoch + 1, train_mapk_avg))
        print('{"chart": "lr_scaled", "x": %d, "y": %.4f}' % (epoch + 1, 1000 * get_learning_rate(optimizer)))
        print('{"chart": "mem_used", "x": %d, "y": %.2f}' % (epoch + 1, psutil.virtual_memory().used / 2 ** 30))
        print('{"chart": "epoch_time", "x": %d, "y": %d}' % (epoch + 1, epoch_duration_time))

        sys.stdout.flush()

        if (sgdr_reset or lr_scheduler_type == "reduce_on_plateau") and epoch - epoch_of_last_improval >= patience:
            print("early abort due to lack of improval", flush=True)
            break

        if max_sgdr_cycles is not None and sgdr_cycle_count >= max_sgdr_cycles:
            print("early abort due to maximum number of sgdr cycles reached", flush=True)
            break

    optim_summary_writer.close()
    train_summary_writer.close()
    val_summary_writer.close()

    train_end_time = time.time()
    print()
    print("Train time: %s" % str(datetime.timedelta(seconds=train_end_time - train_start_time)), flush=True)

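    # Disabled code path: build a stochastic weight averaging (SWA) model from the saved snapshots.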
    if False:
        swa_model = create_model(type=model_type, input_size=image_size, num_classes=len(train_data.categories)).to(
            device)
        swa_update_count = 0
        for f in find_sorted_model_files(output_dir):
            print("merging model '{}' into swa model".format(f), flush=True)
            m = create_model(type=model_type, input_size=image_size, num_classes=len(train_data.categories)).to(device)
            m.load_state_dict(torch.load(f, map_location=device))
            swa_update_count += 1
            moving_average(swa_model, m, 1.0 / swa_update_count)
            # bn_update(train_set_data_loader, swa_model)
        torch.save(swa_model.state_dict(), "{}/swa_model.pth".format(output_dir))

    test_data = TestData(input_dir)
    test_set = TestDataset(test_data.df, image_size, use_extended_stroke_channels)
    test_set_data_loader = \
        DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

    model.load_state_dict(torch.load("{}/model.pth".format(output_dir), map_location=device))
    model = Ensemble([model])

    categories = train_data.categories

    submission_df = test_data.df.copy()
    predictions, predicted_words = predict(model, test_set_data_loader, categories, tta=False)
    submission_df["word"] = predicted_words
    np.save("{}/submission_predictions.npy".format(output_dir), np.array(predictions))
    submission_df.to_csv("{}/submission.csv".format(output_dir), columns=["word"])

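    # Second submission: same model with test-time augmentation (TTA).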
    submission_df = test_data.df.copy()
    predictions, predicted_words = predict(model, test_set_data_loader, categories, tta=True)
    submission_df["word"] = predicted_words
    np.save("{}/submission_predictions_tta.npy".format(output_dir), np.array(predictions))
    submission_df.to_csv("{}/submission_tta.csv".format(output_dir), columns=["word"])

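    # Third submission: ensemble of saved snapshots (via load_ensemble_model), again with TTA.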
    val_set_data_loader = \
        DataLoader(val_set, batch_size=64, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

    model = load_ensemble_model(output_dir, 3, val_set_data_loader, criterion, model_type, image_size, len(categories))
    submission_df = test_data.df.copy()
    predictions, predicted_words = predict(model, test_set_data_loader, categories, tta=True)
    submission_df["word"] = predicted_words
    np.save("{}/submission_predictions_ensemble_tta.npy".format(output_dir), np.array(predictions))
    submission_df.to_csv("{}/submission_ensemble_tta.csv".format(output_dir), columns=["word"])

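    # Per-category precision summary from the diagonal of the validation confusion matrix.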
    confusion, _ = calculate_confusion(model, val_set_data_loader, len(categories))
    precisions = np.array([confusion[c, c] for c in range(confusion.shape[0])])
    percentiles = np.percentile(precisions, q=np.linspace(0, 100, 10))

    print()
    print("Category precision percentiles:")
    print(percentiles)

    print()
    print("Categories sorted by precision:")
    print(np.array(categories)[np.argsort(precisions)])