def ensemble_model(self, features_train, features_test, comments_train,
                    comments_test, y_train, y_test, w1, w2):
    """Score an averaging ensemble on the supplied train/test data.

    Thin wrapper around ``Ensemble.get_ensemble_score`` that fixes the
    strategy to ``'AVERAGING'`` and forwards this object's
    ``imbalance_sampling`` setting.

    Returns:
        Whatever ``Ensemble.get_ensemble_score`` returns.
    """
    return Ensemble.get_ensemble_score(
        'AVERAGING',
        comments_train, features_train,
        comments_test, features_test,
        y_train, y_test,
        w1, w2,
        self.imbalance_sampling,
    )
示例#2
0
    def test(self, epoch):
        """Evaluate the averaged ensemble on ``self.testloader`` and log results.

        Every member is switched to eval mode, then ``Ensemble(self.models)``
        is run over the test loader without gradients. The mean per-batch
        loss and the accuracy (in percent) are written to ``self.writer`` at
        step ``epoch``, and a one-line summary is printed via ``tqdm.write``.
        """
        for member in self.models:
            member.eval()

        ensemble = Ensemble(self.models)
        loss_sum, n_correct, n_seen = 0, 0, 0

        with torch.no_grad():
            for inputs, targets in self.testloader:
                inputs, targets = inputs.cuda(), targets.cuda()

                logits = ensemble(inputs)
                loss_sum += self.criterion(logits, targets).item()
                # index 1 of max() is the argmax over the class dimension
                n_correct += logits.max(1)[1].eq(targets).sum().item()
                n_seen += inputs.size(0)

        avg_loss = loss_sum / len(self.testloader)
        self.writer.add_scalar('test/ensemble_loss', avg_loss, epoch)
        self.writer.add_scalar('test/ensemble_acc', 100 * n_correct / n_seen,
                               epoch)

        tqdm.write('Evaluation  | Ensemble Loss {loss:.4f} Acc {acc:.2%}'.format(
            loss=avg_loss, acc=n_correct / n_seen))
示例#3
0
def load_ensemble_model(base_dir, ensemble_model_count, data_loader, criterion, model_type, input_size, num_classes):
    ensemble_model_candidates = find_sorted_model_files(base_dir)[-(2 * ensemble_model_count):]
    if os.path.isfile("{}/swa_model.pth".format(base_dir)):
        ensemble_model_candidates.append("{}/swa_model.pth".format(base_dir))

    score_to_model = {}
    for model_file_path in ensemble_model_candidates:
        model_file_name = os.path.basename(model_file_path)
        model = create_model(type=model_type, input_size=input_size, num_classes=num_classes).to(device)
        model.load_state_dict(torch.load(model_file_path, map_location=device))

        val_loss_avg, val_mapk_avg, _, _, _, _ = evaluate(model, data_loader, criterion, 3)
        print("ensemble '%s': val_loss=%.4f, val_mapk=%.4f" % (model_file_name, val_loss_avg, val_mapk_avg))

        if len(score_to_model) < ensemble_model_count or min(score_to_model.keys()) < val_mapk_avg:
            if len(score_to_model) >= ensemble_model_count:
                del score_to_model[min(score_to_model.keys())]
            score_to_model[val_mapk_avg] = model

    ensemble = Ensemble(list(score_to_model.values()))

    val_loss_avg, val_mapk_avg, _, _, _, _ = evaluate(ensemble, data_loader, criterion, 3)
    print("ensemble: val_loss=%.4f, val_mapk=%.4f" % (val_loss_avg, val_mapk_avg))

    return ensemble
示例#4
0
def getEnsembleContext(ensemble_path):
    """Build a ``TorchContext`` whose model is an ``Ensemble`` of every entry in a directory.

    Each entry of ``ensemble_path`` is passed to ``TorchContext`` as
    ``file_path``, and its components are initialised to obtain a model.
    Entries are visited in sorted order so the result is reproducible.

    Args:
        ensemble_path: a ``pathlib.Path`` directory.

    Returns:
        The context built from the last (sorted) entry, with ``context.model``
        replaced by an ``Ensemble`` of all entries' models.

    Raises:
        ValueError: if ``ensemble_path`` contains no entries.
    """
    # Local list: collecting into a module-level name would leak models
    # across calls.
    models = []
    context = None
    for file in sorted(ensemble_path.iterdir()):
        context = TorchContext(device,
                               file_path=file,
                               variables=dict(DATASET_PATH=args.dataset_path,
                                              CHECKPOINTS_PATH=""))
        context.init_components()
        models.append(context.model)

    if not models:
        raise ValueError("No model configurations found in {}".format(ensemble_path))

    context.model = Ensemble(models)
    return context
示例#5
0
文件: utils.py 项目: zjysteven/DVERGE
def get_models(args, train=True, as_ensemble=False, model_file=None, leaky_relu=False):
    """Build the ResNet members of an ensemble, optionally loading weights.

    Args:
        args: namespace providing ``arch`` (only ``'resnet'`` is supported),
            ``depth`` and, when ``model_file`` is not given, ``model_num``.
        train: put each model in train mode if True, eval mode otherwise.
        as_ensemble: wrap the models in an ``Ensemble`` (requires ``train=False``).
        model_file: optional checkpoint holding a dict of per-model state
            dicts; one model is built per key.
        leaky_relu: forwarded to ``ResNet``.

    Returns:
        A list of CUDA models, or an ``Ensemble`` in eval mode when
        ``as_ensemble`` is True.

    Raises:
        ValueError: if ``args.arch`` is not supported.
        AssertionError: if ``as_ensemble`` is requested together with ``train=True``.
    """
    # Validate the configuration before allocating CUDA tensors or reading
    # checkpoints, so a misconfiguration fails fast with a readable message.
    if args.arch.lower() != 'resnet':
        raise ValueError('[{:s}] architecture is not supported yet...'.format(args.arch))
    if as_ensemble:
        assert not train, 'Must be in eval mode when getting models to form an ensemble'

    # Per-channel input statistics (these look like the usual CIFAR-10
    # values -- confirm against the dataset in use).
    mean = torch.tensor([0.4914, 0.4822, 0.4465], dtype=torch.float32).cuda()
    std = torch.tensor([0.2023, 0.1994, 0.2010], dtype=torch.float32).cuda()
    normalizer = NormalizeByChannelMeanStd(mean=mean, std=std)

    state_dict = None
    if model_file:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(model_file)
        if train:
            print('Loading pre-trained models...')

    iter_m = state_dict.keys() if model_file else range(args.model_num)

    models = []
    for i in iter_m:
        # Input normalization is included as part of each model.
        model = ModelWrapper(ResNet(depth=args.depth, leaky_relu=leaky_relu), normalizer)
        if model_file:
            model.load_state_dict(state_dict[i])
        if train:
            model.train()
        else:
            model.eval()
        models.append(model.cuda())

    if as_ensemble:
        ensemble = Ensemble(models)
        ensemble.eval()
        return ensemble
    return models
示例#6
0
    def train(self, epoch):
        """Run one epoch of adversarial training, updating each member separately.

        For every batch a fresh ``Ensemble`` over all members is used to craft
        ``Linf_PGD`` examples. Each member is then trained on those examples
        with its own optimizer. Per-member average losses are printed and
        logged under ``train/adv_loss``, and each member's scheduler is stepped
        once at the end of the epoch.

        Args:
            epoch: epoch index, used for printing and as the logging step.
        """
        for member in self.models:
            member.train()

        n_models = len(self.models)
        running = [0] * n_models

        for batch_idx, (inputs, targets) in enumerate(self.get_batch_iterator()):
            inputs = inputs.cuda()
            targets = targets.cuda()

            # Attack the current ensemble, then fit each member to the result.
            adv_inputs = Linf_PGD(Ensemble(self.models), inputs, targets, **self.attack_cfg)

            for idx, member in enumerate(self.models):
                batch_loss = self.criterion(member(adv_inputs), targets)
                running[idx] += batch_loss.item()

                optimizer = self.optimizers[idx]
                optimizer.zero_grad()
                batch_loss.backward()
                optimizer.step()

        summary = ''.join(
            'Model{i:d}: {loss:.4f}  '.format(i=idx + 1, loss=total / (batch_idx + 1))
            for idx, total in enumerate(running))
        tqdm.write('Epoch [%3d] | ' % epoch + summary)

        for idx in range(n_models):
            self.schedulers[idx].step()

        per_model = {str(idx): total / len(self.trainloader)
                     for idx, total in enumerate(running)}
        self.writer.add_scalars('train/adv_loss', per_model, epoch)
示例#7
0
def test_ensemble(path_model_list,
                  model_name,
                  task_set_whole_list,
                  test_task_set,
                  test_batch_size,
                  steps,
                  debug,
                  epsilon,
                  step_size,
                  dataset="taskonomy",
                  default_suffix="/savecheckpoint/checkpoint_150.pth.tar",
                  use_noise=True,
                  momentum=False,
                  use_houdini=False):
    """Evaluate an ensemble of multi-task models under an adversarial attack.

    One sub-model is built per entry of ``task_set_whole_list`` and loaded
    from the matching ``path_model_list`` entry (with ``default_suffix``
    appended). The sub-models are wrapped in an ``Ensemble`` and passed to
    ``mtask_ensemble_test`` together with their task criteria.

    Args:
        path_model_list: checkpoint directories, one per sub-model. The list
            is not modified.
        model_name: architecture name; together with ``dataset`` it selects
            ``config/<arch>_<dataset>_config.json``.
        task_set_whole_list: per sub-model, the task set passed to
            ``get_submodel_ensemble`` / ``get_losses_and_tasks``.
        test_task_set: tasks passed to ``mtask_ensemble_test``.
        test_batch_size, steps, debug, epsilon, step_size: stored on ``args``.
        dataset: dataset name used to pick the config file.
        default_suffix: appended to each ``path_model_list`` entry.
        use_noise, momentum: stored on ``args``.
        use_houdini: forwarded to ``mtask_ensemble_test``.

    Returns:
        Whatever ``mtask_ensemble_test`` returns (it is also printed).
    """
    import json
    import socket

    from learning.test_ensemble import mtask_ensemble_test
    from models.ensemble import Ensemble
    from models.mtask_losses import get_losses_and_tasks

    print('task_set_whole_list', task_set_whole_list)
    print('test_task_set', test_task_set)

    # New list: mutating the caller's list would append the suffix again on
    # every call.
    checkpoint_paths = [each + default_suffix for each in path_model_list]

    # Plain config holder. An argument-less ArgumentParser().parse_args()
    # would read sys.argv and exit on any flag meant for the calling script.
    args = argparse.Namespace()

    args.dataset = dataset
    args.arch = model_name
    args.use_noise = use_noise
    args.momentum = momentum

    config_file_path = "config/{}_{}_config.json".format(
        args.arch, args.dataset)
    with open(config_file_path) as config_file:
        config = json.load(config_file)
    hostname = socket.gethostname()
    if hostname == "deep":
        args.data_dir = config['data-dir_deep']
    elif hostname == 'hulk':
        args.data_dir = '/local/rcs/ECCV/Cityscape/cityscape_dataset'
    else:
        args.data_dir = config['data-dir']

    args.task_set = task_set_whole_list
    args.test_task_set = test_task_set
    args.test_batch_size = test_batch_size
    args.classes = config['classes']
    args.workers = config['workers']
    args.pixel_scale = config['pixel_scale']
    args.steps = steps
    args.debug = debug

    args.epsilon = epsilon
    args.step_size = step_size

    # ADDED FOR CITYSCAPES
    args.random_scale = config['random-scale']
    args.random_rotate = config['random-rotate']
    args.crop_size = config['crop-size']
    args.list_dir = config['list-dir']

    print("PRINTING ARGUMENTS \n")
    for k, v in vars(args).items():  # Prints arguments and contents of config file
        print(k, ':', v)

    model_list = []
    criteria_list = []
    task_list_set = []
    for each, path_model in zip(task_set_whole_list, checkpoint_paths):
        model = get_submodel_ensemble(model_name, args, each)
        if torch.cuda.is_available():
            model.cuda()

        print("=> Loading checkpoint '{}'".format(path_model))
        if torch.cuda.is_available():
            checkpoint_model = torch.load(path_model)
        else:
            # Map stored tensors onto the CPU.
            checkpoint_model = torch.load(
                path_model, map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint_model['state_dict'])
        model_list.append(model)

        print("each ", each)
        criteria, taskonomy_tasks = get_losses_and_tasks(
            args, customized_task_set=each)
        print("criteria got", criteria)
        criteria_list.append(criteria)
        task_list_set.extend(taskonomy_tasks)

    # De-duplicate while keeping first-seen order (deterministic across runs,
    # unlike list(set(...))).
    task_list_set = list(dict.fromkeys(task_list_set))
    val_loader = get_loader(args,
                            split='val',
                            out_name=False,
                            customized_task_set=task_list_set)

    model_whole = Ensemble(model_list)

    advacc_result = mtask_ensemble_test(val_loader,
                                        model_whole,
                                        criteria_list,
                                        args.test_task_set,
                                        args,
                                        use_houdini=use_houdini)
    print(
        "Results: epsilon {} step {} step_size {}  Acc for task {} ::".format(
            args.epsilon, args.steps, args.step_size, args.test_task_set),
        advacc_result)
    return advacc_result
示例#8
0
def main(ens_opt):
    """Evaluate an ensemble of saved captioning models and optionally dump predictions.

    For every model name in ``ens_opt.models``, the saved options and
    vocabulary (``infos*.pkl``) are loaded from ``save/<name>/``, and a CNN
    encoder and an RNN decoder (``ms.select_model``) are built. The models
    are combined into an ``Ensemble``. Evaluation uses either the standard
    ``DataLoader`` or, when ``ens_opt.image_folder`` is set, an external
    image list via ``DataLoaderRaw``.

    Args:
        ens_opt: argparse-style namespace of ensemble options (``models``,
            ``gpu_id``, ``logger``, ``image_folder``, ``image_list``,
            ``output``, ``beam_size``, ``batch_size``, ``dump_json``, ...).
            ``ens_opt.models`` is replaced by the first element of each
            entry, and ``ens_opt.split`` / ``ens_opt.output`` may be
            overwritten.
    """
    # setup gpu
    try:
        gpu_id = int(subprocess.check_output('gpu_getIDs.sh', shell=True))
    except Exception:
        # Fall back to the configured id if the helper script fails or does
        # not print an integer. Not a bare except, so Ctrl-C still propagates.
        print("Failed to get gpu_id (setting gpu_id to %d)" % ens_opt.gpu_id)
        gpu_id = str(ens_opt.gpu_id)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
    ens_opt.logger.warn('GPU ID: %s | available memory: %dM'
                        % (os.environ['CUDA_VISIBLE_DEVICES'], get_gpu_memory(gpu_id)))
    # Deferred imports. Beware seg fault if tf after torch!! -- tensorflow must
    # stay imported before torch. (Presumably also deferred so that
    # CUDA_VISIBLE_DEVICES is set before either framework touches CUDA.)
    import tensorflow as tf  # noqa: F401
    import torch  # noqa: F401
    import models.setup as ms
    from models.ensemble import Ensemble, eval_ensemble, eval_external_ensemble
    import models.cnn as cnn
    from loader import DataLoader, DataLoaderRaw

    ens_opt.models = [_[0] for _ in ens_opt.models]
    print('Models:', ens_opt.models)
    if not ens_opt.output:
        if not len(ens_opt.image_folder):
            evaldir = '%s/evaluations/%s' % (ens_opt.ensemblename,
                                             ens_opt.split)
        else:
            ens_opt.split = ens_opt.image_list.split('/')[-1].split('.')[0]
            print('Split :: ', ens_opt.split)
            evaldir = '%s/evaluations/server_%s' % (ens_opt.ensemblename,
                                                    ens_opt.split)

        if not osp.exists(evaldir):
            os.makedirs(evaldir)
        ens_opt.output = '%s/bw%d' % (evaldir, ens_opt.beam_size)

    # Saved options that may legitimately differ from the ensemble-level
    # options; every other saved option must agree (see the assert below).
    ignore = {
        "batch_size", "beam_size", "start_from", 'cnn_start_from',
        'infos_start_from', "start_from_best", "language_eval", "logger",
        "val_images_use", 'input_data', "loss_version", "region_size",
        "use_adaptive_pooling", "clip_reward", "gpu_id", "max_epochs",
        "modelname", "config", "sample_max", "temperature"
    }

    models_paths = []
    cnn_models = []
    rnn_models = []
    options = []
    # Reformat:
    for m in ens_opt.models:
        models_paths.append('save/%s/model-best.pth' %
                            m)  # FIXME check that cnn-best is the one loaded
        infos_path = "save/%s/infos-best.pkl" % m
        with open(infos_path, 'rb') as f:
            print('Opening %s' % infos_path)
            # NOTE: pickle.load can execute arbitrary code; only load trusted files.
            infos = pickle.load(f, encoding="iso-8859-1")
        iopt = infos['opt']
        # Single-model options: ensemble options overridden by the model's saved ones.
        params = copy.copy(vars(ens_opt))
        params.update(vars(iopt))
        opt = argparse.Namespace(**params)
        opt.modelname = 'save/' + m
        opt.start_from_best = ens_opt.start_from_best
        opt.beam_size = ens_opt.beam_size
        opt.batch_size = ens_opt.batch_size
        opt.logger = ens_opt.logger
        if opt.start_from_best:
            flag = '-best'
            opt.logger.warn('Starting from the best saved model')
        else:
            flag = ''
        opt.cnn_start_from = osp.join(opt.modelname, 'model-cnn%s.pth' % flag)
        opt.infos_start_from = osp.join(opt.modelname, 'infos%s.pkl' % flag)
        opt.start_from = osp.join(opt.modelname, 'model%s.pth' % flag)
        opt.logger.warn('Starting from %s' % opt.start_from)

        # Load infos
        with open(opt.infos_start_from, 'rb') as f:
            print('Opening %s' % opt.infos_start_from)
            infos = pickle.load(f, encoding="iso-8859-1")
            infos['opt'].logger = None
        for k in list(vars(infos['opt']).keys()):
            if k not in ignore and "learning" not in k:
                if k in vars(opt):
                    assert vars(opt)[k] == vars(
                        infos['opt'])[k], (k + ' option not consistent ' +
                                           str(vars(opt)[k]) + ' vs. ' +
                                           str(vars(infos['opt'])[k]))
                else:
                    # copy over options from model
                    vars(opt).update({k: vars(infos['opt'])[k]})

        opt.fliplr = 0
        opt.language_creativity = 0
        opt.seq_per_img = 5
        opt.bootstrap = 0
        opt.sample_cap = 0
        vocab = infos['vocab']  # ix -> word mapping
        # Build CNN model for single branch use
        if opt.cnn_model.startswith('resnet'):
            cnn_model = cnn.ResNetModel(opt)
        elif opt.cnn_model.startswith('vgg'):
            cnn_model = cnn.VggNetModel(opt)
        else:
            print('Unknown model %s' % opt.cnn_model)
            sys.exit(1)

        cnn_model.cuda()
        cnn_model.eval()
        model = ms.select_model(opt)
        model.load()
        model.cuda()
        model.eval()
        options.append(opt)
        cnn_models.append(cnn_model)
        rnn_models.append(model)

    # Create the Data Loader instance
    start = time.time()
    external = False
    if len(ens_opt.image_folder) == 0:
        loader = DataLoader(options[0])
    else:
        external = True
        loader = DataLoaderRaw({
            'folder_path': ens_opt.image_folder,
            'files_list': ens_opt.image_list,
            'batch_size': ens_opt.batch_size
        })
        # vocabulary of the last loaded model
        loader.ix_to_word = vocab

    # Define the ensemble:
    ens_model = Ensemble(rnn_models, cnn_models, ens_opt)

    if external:
        preds = eval_external_ensemble(ens_model, loader, vars(ens_opt))
    else:
        preds, _ = eval_ensemble(ens_model, loader, vars(ens_opt))
    print("Finished evaluation in ", (time.time() - start))
    if ens_opt.dump_json == 1:
        # dump the json; the context manager flushes and closes the file
        with open(ens_opt.output + ".json", 'w') as f:
            json.dump(preds, f)
示例#9
0
def main():
    args = argparser.parse_args()
    print("Arguments:")
    for arg in vars(args):
        print("  {}: {}".format(arg, getattr(args, arg)))
    print()

    input_dir = args.input_dir
    output_dir = args.output_dir
    base_model_dir = args.base_model_dir
    image_size = args.image_size
    augment = args.augment
    use_dummy_image = args.use_dummy_image
    use_progressive_image_sizes = args.use_progressive_image_sizes
    progressive_image_size_min = args.progressive_image_size_min
    progressive_image_size_step = args.progressive_image_size_step
    progressive_image_epoch_step = args.progressive_image_epoch_step
    batch_size = args.batch_size
    batch_iterations = args.batch_iterations
    test_size = args.test_size
    train_on_unrecognized = args.train_on_unrecognized
    num_category_shards = args.num_category_shards
    category_shard = args.category_shard
    exclude_categories = args.exclude_categories
    eval_train_mapk = args.eval_train_mapk
    mapk_topk = args.mapk_topk
    num_shard_preload = args.num_shard_preload
    num_shard_loaders = args.num_shard_loaders
    num_workers = args.num_workers
    pin_memory = args.pin_memory
    epochs_to_train = args.epochs
    lr_scheduler_type = args.lr_scheduler
    lr_patience = args.lr_patience
    lr_min = args.lr_min
    lr_max = args.lr_max
    lr_min_decay = args.lr_min_decay
    lr_max_decay = args.lr_max_decay
    optimizer_type = args.optimizer
    loss_type = args.loss
    loss2_type = args.loss2
    loss2_start_sgdr_cycle = args.loss2_start_sgdr_cycle
    model_type = args.model
    patience = args.patience
    sgdr_cycle_epochs = args.sgdr_cycle_epochs
    sgdr_cycle_epochs_mult = args.sgdr_cycle_epochs_mult
    sgdr_cycle_end_prolongation = args.sgdr_cycle_end_prolongation
    sgdr_cycle_end_patience = args.sgdr_cycle_end_patience
    max_sgdr_cycles = args.max_sgdr_cycles

    use_extended_stroke_channels = model_type in ["cnn", "residual_cnn", "fc_cnn", "hc_fc_cnn"]

    base_model_dirs = [
        "/storage/models/quickdraw/l1",
        "/storage/models/quickdraw/l2",
        "/storage/models/quickdraw/l3",
        "/storage/models/quickdraw/l4"
    ]

    model_categories = [
        ['vase', 'flip flops', 'hospital', 'lollipop', 'hammer', 'toothbrush', 'fork', 'moustache', 'sailboat', 'couch', 'underwear', 'church', 'tooth', 'penguin', 'apple', 'bulldozer', 'drums', 'kangaroo', 'alarm clock', 'submarine', 'spider', 'owl', 'stethoscope', 'mushroom', 'popsicle', 'airplane', 'flamingo', 'backpack', 'hot air balloon', 'toilet', 'candle', 'palm tree', 'camera', 'sock', 'power outlet', 'teapot', 'computer', 'triangle', 'diamond', 'snowflake', 'donut', 'compass', 'stitches', 'eyeglasses', 'paper clip', 'carrot', 'binoculars', 'envelope', 'cactus', 'flashlight', 'sun', 'traffic light', 'television', 'crown', 'pineapple', 'strawberry', 'saw', 'bee', 'megaphone', 'squirrel', 'wristwatch', 'flower', 'fish', 'rain', 'key', 'hourglass', 'clock', 'sheep', 'tennis racquet', 'star', 'parachute', 'giraffe', 'rollerskates', 'The Mona Lisa', 'sword', 'butterfly', 'mermaid', 'wine glass', 'bowtie', 'angel', 'eye', 'stairs', 'scorpion', 'house plant', 'anvil', 'chair', 'umbrella', 'see saw', 'snail', 'The Eiffel Tower', 'ladder', 'camel', 'octopus', 'skateboard', 'harp', 'snowman', 'skull', 'swing set', 'ice cream', 'stop sign', 'headphones', 'helicopter'],
        ['banana', 'parrot', 'tree', 'lipstick', 'teddy-bear', 'horse', 'arm', 'basket', 'necklace', 'baseball bat', 'sandwich', 'zebra', 'telephone', 'elephant', 'hot dog', 'streetlight', 'shorts', 'face', 'table', 'cow', 'postcard', 'boomerang', 'pear', 'shovel', 'zigzag', 'rhinoceros', 'onion', 'picture frame', 'saxophone', 'hat', 'cruise ship', 'train', 'ceiling fan', 'nose', 'belt', 'speedboat', 'bridge', 'barn', 'door', 'skyscraper', 'fence', 'scissors', 'shark', 'rake', 'microphone', 'ear', 'whale', 'fireplace', 'lightning', 'screwdriver', 'jacket', 'crab', 'roller coaster', 'cannon', 'garden', 'helmet', 'dresser', 'bed', 'nail', 'swan', 'fan', 'bat', 'rabbit', 'mountain', 'shoe', 'floor lamp', 'soccer ball', 'mailbox', 'laptop', 'washing machine', 'drill', 'calculator', 'ant', 'chandelier', 'hamburger', 'lighthouse', 'sea turtle', 'goatee', 'pizza', 'crocodile', 'dolphin', 'rainbow', 'frying pan', 'leaf', 'mouth', 'snorkel', 'remote control', 'light bulb', 'axe', 'hand', 'pig', 'sink', 'baseball', 'lion', 'pants', 'windmill', 'castle', 'dumbbell', 'hedgehog', 'tent', 'wine bottle', 'bandage'],
        ['animal migration', 'monkey', 'watermelon', 'radio', 'panda', 'beach', 'dishwasher', 'calendar', 'peas', 'bottlecap', 'bird', 'police car', 'ambulance', 'clarinet', 'mouse', 'snake', 'asparagus', 'cloud', 'finger', 'dragon', 'foot', 'microwave', 'cookie', 'book', 'tiger', 'sleeping bag', 'canoe', 'toothpaste', 'toe', 'broom', 'tractor', 'matches', 'brain', 'bread', 'bracelet', 'purse', 'knee', 'diving board', 'peanut', 'paintbrush', 'lantern', 'firetruck', 'pliers', 'duck', 'map', 't-shirt', 'toaster', 'yoga', 'lobster', 'elbow', 'passport', 'waterslide', 'broccoli', 'moon', 'campfire', 'jail', 'basketball', 'sweater', 'fire hydrant', 'feather', 'flying saucer', 'grass', 'spoon', 'cell phone', 'smiley face', 'beard', 'wheel', 'house'],
        ['camouflage', 'mug', 'cello', 'hurricane', 'bus', 'truck', 'pond', 'birthday cake', 'garden hose', 'cake', 'school bus', 'leg', 'van', 'guitar', 'cup', 'pool', 'hockey stick', 'bear', 'marker', 'blackberry', 'squiggle', 'tornado', 'crayon', 'circle', 'pickup truck', 'coffee cup', 'cooler', 'square', 'river', 'paint can', 'oven', 'string bean', 'The Great Wall of China', 'hockey puck', 'car', 'spreadsheet', 'trombone', 'bucket', 'trumpet', 'eraser', 'line', 'pencil', 'pillow', 'blueberry', 'frog', 'bush', 'keyboard', 'steak', 'potato', 'ocean', 'bicycle', 'mosquito', 'stereo', 'dog', 'suitcase', 'violin', 'octagon', 'bathtub', 'raccoon', 'hot tub', 'cat', 'bench', 'piano', 'stove', 'golf club', 'motorbike', 'grapes', 'hexagon']
    ]

    categories = read_lines("{}/categories.txt".format(input_dir))

    test_data = TestData(input_dir)
    test_set = TestDataset(test_data.df, image_size, use_extended_stroke_channels)
    test_set_data_loader = \
        DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

    all_model_predictions = []
    for base_model_dir in base_model_dirs:
        print("Processing model dir '{}'".format(base_model_dir), flush=True)

        ms = []
        for model_file_path in glob.glob("{}/model-*.pth".format(base_model_dir)):
            m = create_model(type=model_type, input_size=image_size, num_classes=len(categories)).to(device)
            m.load_state_dict(torch.load(model_file_path, map_location=device))
            ms.append(m)
        model = Ensemble(ms)

        print("Predicting...", flush=True)
        all_model_predictions.append(predict(model, test_set_data_loader, tta=True))

    print("Merging predictions...", flush=True)

    final_predictions = all_model_predictions[0].copy()
    cumulative_categories = model_categories[0].copy()
    for m in range(1, len(all_model_predictions)):
        model_predictions = all_model_predictions[m]
        for p in range(len(model_predictions)):
            final_prediction_scores = final_predictions[p][0]
            final_prediction_categories = final_predictions[p][1]
            current_prediction_scores = model_predictions[p][0]
            current_prediction_categories = model_predictions[p][1]
            for r in range(len(final_prediction_scores)):
                final_prediction_score = final_prediction_scores[r]
                final_prediction_category = final_prediction_categories[r]
                current_prediction_score = current_prediction_scores[r]
                current_prediction_category = current_prediction_categories[r]
                final_category_contained = categories[final_prediction_category] in cumulative_categories
                current_category_contained = categories[current_prediction_category] in model_categories[m]
                if final_category_contained == current_category_contained:
                    if current_prediction_score > final_prediction_score:
                        final_prediction_scores[r] = current_prediction_score
                        final_prediction_categories[r] = current_prediction_category
                else:
                    if current_category_contained:
                        final_prediction_scores[r] = current_prediction_score
                        final_prediction_categories[r] = current_prediction_category
        cumulative_categories.extend(model_categories[m])

    words = np.array([c.replace(" ", "_") for c in categories])

    submission_df = test_data.df.copy()
    submission_df["word"] = [" ".join(words[fp[1]]) for fp in final_predictions]
    submission_df.to_csv("{}/submission.csv".format(output_dir), columns=["word"])
示例#10
0
    def train(self, epoch):
        """Run one training epoch of the ensemble with a gradient-coherence penalty.

        Per batch, the minimised objective is::

            ce_loss / n_models + self.coeff * coh_loss + adv_loss / n_models

        where
          * ``ce_loss`` is the sum over members of the criterion on the clean
            inputs;
          * ``coh_loss`` is the batch mean of
            ``log(sum_pairs exp(cos_sim) + self.log_offset)`` over the pairwise
            cosine similarities of the members' input gradients;
          * ``adv_loss`` (only when ``self.plus_adv``) is the sum over members
            of the criterion on ``Linf_PGD`` examples crafted against
            ``Ensemble(self.models)``.

        A single ``self.optimizer`` / ``self.scheduler`` updates all members.
        Averages are printed and written to ``self.writer`` at step ``epoch``.

        Args:
            epoch: epoch index, used for printing and as the logging step.
        """
        for m in self.models:
            m.train()

        # NOTE(review): ``losses`` is accumulated below but never printed or logged.
        losses = 0
        ce_losses = 0
        coh_losses = 0
        adv_losses = 0

        batch_iter = self.get_batch_iterator()
        for batch_idx, (inputs, targets) in enumerate(batch_iter):
            inputs, targets = inputs.cuda(), targets.cuda()
            # Required so autograd.grad below can differentiate each member's
            # loss with respect to the inputs.
            inputs.requires_grad = True

            if self.plus_adv:
                ensemble = Ensemble(self.models)
                adv_inputs = Linf_PGD(ensemble, inputs, targets,
                                      **self.attack_cfg)

            ce_loss = 0
            adv_loss = 0
            grads = []
            for i, m in enumerate(self.models):
                # for coherence loss
                outputs = m(inputs)
                loss = self.criterion(outputs, targets)
                # create_graph=True keeps the gradient differentiable so the
                # coherence penalty can be backpropagated into the parameters.
                grad = autograd.grad(loss, inputs, create_graph=True)[0]
                grad = grad.flatten(start_dim=1)  # one flat gradient vector per sample
                grads.append(grad)

                # for standard loss
                # A second forward pass on a detached copy of the inputs, so
                # this term's graph does not pass through the input gradient.
                # NOTE(review): in train mode this extra pass also affects any
                # BatchNorm running statistics -- confirm that is intended.
                ce_loss += self.criterion(m(inputs.clone().detach()), targets)

                if self.plus_adv:
                    # for adv loss
                    adv_loss += self.criterion(m(adv_inputs), targets)

            # Cosine similarity for every unordered pair (i < j) of members.
            cos_sim = []
            for i in range(len(self.models)):
                for j in range(i + 1, len(self.models)):
                    cos_sim.append(
                        F.cosine_similarity(grads[i], grads[j], dim=-1))

            # Shape (batch, n_models * (n_models - 1) / 2), checked by the assert.
            cos_sim = torch.stack(cos_sim, dim=-1)
            assert cos_sim.shape == (inputs.size(0),
                                     (len(self.models) *
                                      (len(self.models) - 1)) // 2)
            # Smooth maximum (log-sum-exp) over pairs, averaged over the batch.
            coh_loss = torch.log(cos_sim.exp().sum(dim=-1) +
                                 self.log_offset).mean()

            loss = ce_loss / len(
                self.models) + self.coeff * coh_loss + adv_loss / len(
                    self.models)

            losses += loss.item()
            # NOTE(review): ce_losses accumulates the *sum* over members, while
            # the optimised loss uses the per-member mean.
            ce_losses += ce_loss.item()
            coh_losses += coh_loss.item()
            if self.plus_adv:
                adv_losses += adv_loss.item()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        self.scheduler.step()

        # NOTE(review): raises UnboundLocalError if the batch iterator was
        # empty (batch_idx never assigned). Printed values are divided by the
        # number of iterated batches, logged ones by len(self.trainloader).
        print_message = 'Epoch [%3d] | ce_loss: %.4f\tcoh_loss: %.4f\tadv_loss: %.4f' % (
            epoch, ce_losses / (batch_idx + 1), coh_losses /
            (batch_idx + 1), adv_losses / (batch_idx + 1))
        tqdm.write(print_message)

        self.writer.add_scalar('train/ce_loss',
                               ce_losses / len(self.trainloader), epoch)
        self.writer.add_scalar('train/coh_loss',
                               coh_losses / len(self.trainloader), epoch)
        self.writer.add_scalar('train/adv_loss',
                               adv_losses / len(self.trainloader), epoch)
示例#11
0
    def train(self, epoch):
        """Run one training epoch with ensemble-entropy and log-det diversity terms.

        Per batch, the minimised objective is::

            ce_loss - self.alpha * ensemble_entropy - self.beta * log_det + adv_loss

        where
          * ``ce_loss`` is the sum over members of the criterion on the clean
            inputs;
          * ``ensemble_entropy`` is the batch-mean entropy of the members'
            averaged softmax;
          * ``log_det`` is the batch-mean ``logdet(G + self.det_offset * I)``,
            where ``G`` is the Gram matrix of the members' L2-normalised
            non-target-class probabilities;
          * ``adv_loss`` (only when ``self.plus_adv``) is the sum over members
            of the criterion on ``Linf_PGD`` examples crafted against
            ``Ensemble(self.models)``.

        A single ``self.optimizer`` / ``self.scheduler`` updates all members.
        Per-batch averages of every term are printed once and written to
        ``self.writer`` at step ``epoch``.

        Args:
            epoch: epoch index, used for printing and as the logging step.
        """
        for m in self.models:
            m.train()

        losses = 0
        ce_losses = 0
        ee_losses = 0
        det_losses = 0
        adv_losses = 0

        # NOTE(review): the class count is hard-coded; assumes a 10-class dataset.
        num_classes = 10
        num_models = len(self.models)

        batch_iter = self.get_batch_iterator()
        for batch_idx, (inputs, targets) in enumerate(batch_iter):
            inputs, targets = inputs.cuda(), targets.cuda()

            if self.plus_adv:
                ensemble = Ensemble(self.models)
                adv_inputs = Linf_PGD(ensemble, inputs, targets,
                                      **self.attack_cfg)

            # One-hot labels, shape (batch, num_classes).
            y_true = torch.zeros(inputs.size(0), num_classes).cuda()
            y_true.scatter_(1, targets.view(-1, 1), 1)
            # Boolean mask of the non-target classes, shape (batch, num_classes).
            # It is identical for every member, so it is built once per batch.
            non_target_mask = y_true == 0

            ce_loss = 0
            adv_loss = 0
            mask_non_y_pred = []
            ensemble_probs = 0
            for m in self.models:
                outputs = m(inputs)
                ce_loss += self.criterion(outputs, targets)

                # Drop the target-class probability: shape (batch, num_classes - 1).
                y_pred = F.softmax(outputs, dim=-1)
                mask_non_y_pred.append(
                    torch.masked_select(y_pred, non_target_mask).reshape(
                        -1, num_classes - 1))

                # Accumulate for the ensemble-entropy term.
                ensemble_probs += y_pred

                if self.plus_adv:
                    adv_loss += self.criterion(m(adv_inputs), targets)

            ensemble_probs = ensemble_probs / num_models
            ensemble_entropy = torch.sum(
                -torch.mul(ensemble_probs,
                           torch.log(ensemble_probs + self.log_offset)),
                dim=-1).mean()

            # Shape (batch, num_models, num_classes - 1); rows L2-normalised.
            mask_non_y_pred = torch.stack(mask_non_y_pred, dim=1)
            assert mask_non_y_pred.shape == (inputs.size(0), num_models,
                                             num_classes - 1)
            mask_non_y_pred = mask_non_y_pred / torch.norm(
                mask_non_y_pred, p=2, dim=-1, keepdim=True)
            # Gram matrix of the members' vectors: (batch, num_models, num_models).
            matrix = torch.matmul(mask_non_y_pred,
                                  mask_non_y_pred.permute(0, 2, 1))
            # det_offset * I is added presumably to keep the matrix non-singular;
            # the log-determinant is averaged over the batch.
            log_det = torch.logdet(
                matrix + self.det_offset *
                torch.eye(num_models, device=matrix.device).unsqueeze(0)
            ).mean()

            loss = ce_loss - self.alpha * ensemble_entropy - self.beta * log_det + adv_loss

            losses += loss.item()
            ce_losses += ce_loss.item()
            ee_losses += ensemble_entropy.item()
            det_losses += -log_det.item()
            if self.plus_adv:
                adv_losses += adv_loss.item()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

        self.scheduler.step()

        n_batches = batch_idx + 1
        print_message = ('Epoch [%3d] | loss: %.4f\tce_loss: %.4f\tee_loss: %.4f'
                         '\tdet_loss: %.4f\tadv_loss: %.4f' % (
                             epoch, losses / n_batches, ce_losses / n_batches,
                             ee_losses / n_batches, det_losses / n_batches,
                             adv_losses / n_batches))
        tqdm.write(print_message)

        self.writer.add_scalar('train/ce_loss',
                               ce_losses / len(self.trainloader), epoch)
        self.writer.add_scalar('train/ee_loss',
                               ee_losses / len(self.trainloader), epoch)
        self.writer.add_scalar('train/det_loss',
                               det_losses / len(self.trainloader), epoch)
        self.writer.add_scalar('train/adv_loss',
                               adv_losses / len(self.trainloader), epoch)
示例#12
0
文件: main.py 项目: rjk2147/MPC
    from envs.pinkpanther import PinkPantherEnv
    import pickle as pkl
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    env = PinkPantherEnv(render=False)
    env = rewWrapper(env)

    # rew_fn = lambda x: x[0] - 0.5*x[1]

    state_dim, act_dim = env.observation_space.shape[
        0], env.action_space.shape[0]
    train_steps = np.arange(101) * 100
    n_trials, n_runs = 20, 5
    results, losses = dict(), dict()
    pos_paths = []
    print('Starting')
    models = [Ensemble(state_dim, act_dim) for _ in range(n_runs)]
    datas = [[] for i in range(n_runs)]
    for i in range(len(train_steps)):
        start = time.time()
        results[train_steps[i]] = []
        losses[train_steps[i]] = []
        for j in range(n_runs):
            pred_rs, real_rs = [], []
            if train_steps[i] > 0:
                models[j], loss, datas[j] = train_model(
                    100, env, models[j], datas[j])
                losses[train_steps[i]].append(loss)
                pkl.dump(losses, open('losses.pkl', 'wb+'))
            planner = CEM(modelSim(models[j]), env.action_space, nsteps=nsteps)
            print(
                str(train_steps[i]) + ' Model trained in ' +
示例#13
0
def main():
    """Train a classifier, then write test-set submissions and diagnostics.

    Parses command-line arguments with ``argparser`` and then:

    1. Builds train/val datasets from the first shard returned by
       ``TrainDataProvider`` and creates the model and optimizer, optionally
       resuming from checkpoints copied out of ``base_model_dir``.
    2. Trains for up to ``args.epochs`` epochs. Gradients are accumulated over
       ``batch_iterations`` batches per optimizer step. The learning rate follows
       either a restarting cosine-annealing schedule ("SGDR", see
       ``sgdr_*`` arguments) or ``ReduceLROnPlateau`` on validation MAP@k.
       The best model overall is saved to ``model.pth`` / ``optimizer.pth``;
       the best model within each SGDR cycle is saved to ``model-<i>.pth``.
    3. Predicts on the test set with the best single model (``tta=False`` and
       ``tta=True``) and with the model returned by ``load_ensemble_model``,
       writing CSV submissions and ``.npy`` prediction arrays to ``output_dir``.
    4. Prints percentiles of the confusion-matrix diagonal on the validation
       set and the categories sorted by that value.
    """
    args = argparser.parse_args()
    print("Arguments:")
    for arg in vars(args):
        print("  {}: {}".format(arg, getattr(args, arg)))
    print()

    input_dir = args.input_dir
    output_dir = args.output_dir
    base_model_dir = args.base_model_dir
    image_size = args.image_size
    augment = args.augment
    use_dummy_image = args.use_dummy_image
    use_progressive_image_sizes = args.use_progressive_image_sizes
    progressive_image_size_min = args.progressive_image_size_min
    progressive_image_size_step = args.progressive_image_size_step
    progressive_image_epoch_step = args.progressive_image_epoch_step
    batch_size = args.batch_size
    batch_iterations = args.batch_iterations
    test_size = args.test_size
    train_on_val = args.train_on_val
    fold = args.fold
    train_on_unrecognized = args.train_on_unrecognized
    confusion_set = args.confusion_set
    num_category_shards = args.num_category_shards
    category_shard = args.category_shard
    eval_train_mapk = args.eval_train_mapk
    mapk_topk = args.mapk_topk
    num_shard_preload = args.num_shard_preload
    num_shard_loaders = args.num_shard_loaders
    num_workers = args.num_workers
    pin_memory = args.pin_memory
    epochs_to_train = args.epochs
    lr_scheduler_type = args.lr_scheduler
    lr_patience = args.lr_patience
    lr_min = args.lr_min
    lr_max = args.lr_max
    lr_min_decay = args.lr_min_decay
    lr_max_decay = args.lr_max_decay
    optimizer_type = args.optimizer
    loss_type = args.loss
    bootstraping_loss_ratio = args.bootstraping_loss_ratio
    loss2_type = args.loss2
    loss2_start_sgdr_cycle = args.loss2_start_sgdr_cycle
    model_type = args.model
    patience = args.patience
    sgdr_cycle_epochs = args.sgdr_cycle_epochs
    sgdr_cycle_epochs_mult = args.sgdr_cycle_epochs_mult
    sgdr_cycle_end_prolongation = args.sgdr_cycle_end_prolongation
    sgdr_cycle_end_patience = args.sgdr_cycle_end_patience
    max_sgdr_cycles = args.max_sgdr_cycles

    # Only these model types receive the extended stroke channels in their input.
    use_extended_stroke_channels = model_type in ["cnn", "residual_cnn", "fc_cnn", "hc_fc_cnn"]
    print("use_extended_stroke_channels: {}".format(use_extended_stroke_channels), flush=True)

    # Image sizes to step through when progressive resizing is enabled.
    progressive_image_sizes = list(range(progressive_image_size_min, image_size + 1, progressive_image_size_step))

    train_data_provider = TrainDataProvider(
        input_dir,
        50,
        num_shard_preload=num_shard_preload,
        num_workers=num_shard_loaders,
        test_size=test_size,
        fold=fold,
        train_on_unrecognized=train_on_unrecognized,
        confusion_set=confusion_set,
        num_category_shards=num_category_shards,
        category_shard=category_shard,
        train_on_val=train_on_val)

    train_data = train_data_provider.get_next()

    train_set = TrainDataset(train_data.train_set_df, len(train_data.categories), image_size, use_extended_stroke_channels, augment, use_dummy_image)
    stratified_sampler = StratifiedSampler(train_data.train_set_df["category"], batch_size * batch_iterations)
    train_set_data_loader = \
        DataLoader(train_set, batch_size=batch_size, shuffle=False, sampler=stratified_sampler, num_workers=num_workers,
                   pin_memory=pin_memory)

    val_set = TrainDataset(train_data.val_set_df, len(train_data.categories), image_size, use_extended_stroke_channels, False, use_dummy_image)
    val_set_data_loader = \
        DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

    if base_model_dir:
        # Resume: copy every checkpoint from the base dir, then load the main one.
        for base_file_path in glob.glob("{}/*.pth".format(base_model_dir)):
            shutil.copyfile(base_file_path, "{}/{}".format(output_dir, os.path.basename(base_file_path)))
        model = create_model(type=model_type, input_size=image_size, num_classes=len(train_data.categories)).to(device)
        model.load_state_dict(torch.load("{}/model.pth".format(output_dir), map_location=device))
        optimizer = create_optimizer(optimizer_type, model, lr_max)
        if os.path.isfile("{}/optimizer.pth".format(output_dir)):
            # map_location mirrors the model load above so a checkpoint written on
            # one device (e.g. GPU) can be resumed on another (e.g. CPU-only).
            optimizer.load_state_dict(torch.load("{}/optimizer.pth".format(output_dir), map_location=device))
            adjust_initial_learning_rate(optimizer, lr_max)
            adjust_learning_rate(optimizer, lr_max)
    else:
        model = create_model(type=model_type, input_size=image_size, num_classes=len(train_data.categories)).to(device)
        optimizer = create_optimizer(optimizer_type, model, lr_max)

    torch.save(model.state_dict(), "{}/model.pth".format(output_dir))

    # Continue numbering per-cycle snapshots after any existing model-<i>.pth files.
    ensemble_model_index = 0
    for model_file_path in glob.glob("{}/model-*.pth".format(output_dir)):
        model_file_name = os.path.basename(model_file_path)
        model_index = int(model_file_name.replace("model-", "").replace(".pth", ""))
        ensemble_model_index = max(ensemble_model_index, model_index + 1)

    if confusion_set is not None:
        # NOTE(review): hard-coded absolute source path; only valid on the original machine.
        shutil.copyfile(
            "/storage/models/quickdraw/seresnext50_confusion/confusion_set_{}.txt".format(confusion_set),
            "{}/confusion_set.txt".format(output_dir))

    # Batches per epoch; SGDR progress is measured as sgdr_iterations / epoch_iterations.
    epoch_iterations = ceil(len(train_set) / batch_size)

    print("train_set_samples: {}, val_set_samples: {}".format(len(train_set), len(val_set)), flush=True)
    print()

    global_val_mapk_best_avg = float("-inf")
    sgdr_cycle_val_mapk_best_avg = float("-inf")

    lr_scheduler = CosineAnnealingLR(optimizer, T_max=sgdr_cycle_epochs, eta_min=lr_min)

    optim_summary_writer = SummaryWriter(log_dir="{}/logs/optim".format(output_dir))
    train_summary_writer = SummaryWriter(log_dir="{}/logs/train".format(output_dir))
    val_summary_writer = SummaryWriter(log_dir="{}/logs/val".format(output_dir))

    current_sgdr_cycle_epochs = sgdr_cycle_epochs
    sgdr_next_cycle_end_epoch = current_sgdr_cycle_epochs + sgdr_cycle_end_prolongation
    sgdr_iterations = 0
    sgdr_cycle_count = 0
    batch_count = 0
    epoch_of_last_improval = 0

    lr_scheduler_plateau = ReduceLROnPlateau(optimizer, mode="max", min_lr=lr_min, patience=lr_patience, factor=0.8, threshold=1e-4)

    # Chart declarations consumed by whatever parses this process's stdout.
    print('{"chart": "best_val_mapk", "axis": "epoch"}')
    print('{"chart": "val_mapk", "axis": "epoch"}')
    print('{"chart": "val_loss", "axis": "epoch"}')
    print('{"chart": "val_accuracy@1", "axis": "epoch"}')
    print('{"chart": "val_accuracy@3", "axis": "epoch"}')
    print('{"chart": "val_accuracy@5", "axis": "epoch"}')
    print('{"chart": "val_accuracy@10", "axis": "epoch"}')
    print('{"chart": "sgdr_cycle", "axis": "epoch"}')
    print('{"chart": "mapk", "axis": "epoch"}')
    print('{"chart": "loss", "axis": "epoch"}')
    print('{"chart": "lr_scaled", "axis": "epoch"}')
    print('{"chart": "mem_used", "axis": "epoch"}')
    print('{"chart": "epoch_time", "axis": "epoch"}')

    train_start_time = time.time()

    criterion = create_criterion(loss_type, len(train_data.categories), bootstraping_loss_ratio)

    # A "center" criterion exposes its own parameters via ``criterion.center``.
    # The model optimizer does not own them, so they get a separate optimizer.
    # Track whether the *current* criterion is a center loss so that this
    # optimizer follows a later switch to ``loss2``.
    center_loss_active = loss_type == "center"
    optimizer_centloss = None
    if center_loss_active:
        optimizer_centloss = torch.optim.SGD(criterion.center.parameters(), lr=0.01)
    # The switch to loss2 happens at most once, so criterion state is not reset
    # at every later SGDR restart.
    loss2_active = False

    for epoch in range(epochs_to_train):
        epoch_start_time = time.time()

        print("memory used: {:.2f} GB".format(psutil.virtual_memory().used / 2 ** 30), flush=True)

        if use_progressive_image_sizes:
            next_image_size = \
                progressive_image_sizes[min(epoch // progressive_image_epoch_step, len(progressive_image_sizes) - 1)]

            if train_set.image_size != next_image_size:
                print("changing image size to {}".format(next_image_size), flush=True)
                train_set.image_size = next_image_size
                val_set.image_size = next_image_size

        model.train()

        train_loss_sum_t = zero_item_tensor()
        train_mapk_sum_t = zero_item_tensor()

        epoch_batch_iter_count = 0

        for b, batch in enumerate(train_set_data_loader):
            images, categories, categories_one_hot = \
                batch[0].to(device, non_blocking=True), \
                batch[1].to(device, non_blocking=True), \
                batch[2].to(device, non_blocking=True)

            if lr_scheduler_type == "cosine_annealing":
                # Position within the current cycle in fractional epochs, capped at the cycle length.
                lr_scheduler.step(epoch=min(current_sgdr_cycle_epochs, sgdr_iterations / epoch_iterations))

            # Gradient accumulation: clear gradients only at the start of each
            # group of batch_iterations batches.
            if b % batch_iterations == 0:
                optimizer.zero_grad()
                if center_loss_active:
                    # Center parameters are not in `optimizer`; without this their
                    # gradients would keep accumulating (and be rescaled) every step.
                    optimizer_centloss.zero_grad()

            prediction_logits = model(images)
            # if prediction_logits.size(1) == len(class_weights):
            #     criterion.weight = class_weights
            loss = criterion(prediction_logits, get_loss_target(criterion, categories, categories_one_hot))
            loss.backward()

            with torch.no_grad():
                train_loss_sum_t += loss
                if eval_train_mapk:
                    train_mapk_sum_t += mapk(prediction_logits, categories,
                                             topk=min(mapk_topk, len(train_data.categories)))

            # Step at the end of each accumulation group and on the final batch.
            if (b + 1) % batch_iterations == 0 or (b + 1) == len(train_set_data_loader):
                optimizer.step()
                if center_loss_active:
                    for param in criterion.center.parameters():
                        if param.grad is not None:
                            # NOTE(review): presumably undoes a 0.5 center-loss weight
                            # applied inside the criterion — confirm against create_criterion.
                            param.grad.data *= (1. / 0.5)
                    optimizer_centloss.step()

            sgdr_iterations += 1
            batch_count += 1
            epoch_batch_iter_count += 1

            optim_summary_writer.add_scalar("lr", get_learning_rate(optimizer), batch_count + 1)

        # Load the next data shard for the following epoch. The val dataset's df
        # is replaced too, so the evaluation below runs on the new val split.
        # TODO: recalculate epoch_iterations and maybe other values?
        train_data = train_data_provider.get_next()
        train_set.df = train_data.train_set_df
        val_set.df = train_data.val_set_df
        epoch_iterations = ceil(len(train_set) / batch_size)
        stratified_sampler.class_vector = train_data.train_set_df["category"]

        # Guard against an empty loader so the averages do not divide by zero.
        num_train_batches = max(epoch_batch_iter_count, 1)
        train_loss_avg = train_loss_sum_t.item() / num_train_batches
        train_mapk_avg = train_mapk_sum_t.item() / num_train_batches

        val_loss_avg, val_mapk_avg, val_accuracy_top1_avg, val_accuracy_top3_avg, val_accuracy_top5_avg, val_accuracy_top10_avg = \
            evaluate(model, val_set_data_loader, criterion, mapk_topk)

        if lr_scheduler_type == "reduce_on_plateau":
            lr_scheduler_plateau.step(val_mapk_avg)

        # Best snapshot within the current SGDR cycle (ensemble candidate).
        model_improved_within_sgdr_cycle = check_model_improved(sgdr_cycle_val_mapk_best_avg, val_mapk_avg)
        if model_improved_within_sgdr_cycle:
            torch.save(model.state_dict(), "{}/model-{}.pth".format(output_dir, ensemble_model_index))
            sgdr_cycle_val_mapk_best_avg = val_mapk_avg

        # Best model over the whole run.
        model_improved = check_model_improved(global_val_mapk_best_avg, val_mapk_avg)
        ckpt_saved = False
        if model_improved:
            torch.save(model.state_dict(), "{}/model.pth".format(output_dir))
            torch.save(optimizer.state_dict(), "{}/optimizer.pth".format(output_dir))
            global_val_mapk_best_avg = val_mapk_avg
            epoch_of_last_improval = epoch
            ckpt_saved = True

        # Restart the cosine cycle once it has ended and there has been no
        # improvement for sgdr_cycle_end_patience epochs.
        sgdr_reset = False
        if (lr_scheduler_type == "cosine_annealing") and (epoch + 1 >= sgdr_next_cycle_end_epoch) and (epoch - epoch_of_last_improval >= sgdr_cycle_end_patience):
            sgdr_iterations = 0
            current_sgdr_cycle_epochs = int(current_sgdr_cycle_epochs * sgdr_cycle_epochs_mult)
            sgdr_next_cycle_end_epoch = epoch + 1 + current_sgdr_cycle_epochs + sgdr_cycle_end_prolongation

            ensemble_model_index += 1
            sgdr_cycle_val_mapk_best_avg = float("-inf")
            sgdr_cycle_count += 1
            sgdr_reset = True

            # Decay both ends of the LR range per cycle; keep max >= min.
            new_lr_min = lr_min * (lr_min_decay ** sgdr_cycle_count)
            new_lr_max = lr_max * (lr_max_decay ** sgdr_cycle_count)
            new_lr_max = max(new_lr_max, new_lr_min)

            adjust_learning_rate(optimizer, new_lr_max)
            lr_scheduler = CosineAnnealingLR(optimizer, T_max=current_sgdr_cycle_epochs, eta_min=new_lr_min)
            if loss2_type is not None and not loss2_active and sgdr_cycle_count >= loss2_start_sgdr_cycle:
                print("switching to loss type '{}'".format(loss2_type), flush=True)
                criterion = create_criterion(loss2_type, len(train_data.categories), bootstraping_loss_ratio)
                loss2_active = True
                # Keep the center-loss optimizer in sync with the new criterion:
                # drop the stale one, or create one if loss2 is itself a center loss.
                center_loss_active = loss2_type == "center"
                if center_loss_active:
                    optimizer_centloss = torch.optim.SGD(criterion.center.parameters(), lr=0.01)
                else:
                    optimizer_centloss = None

        optim_summary_writer.add_scalar("sgdr_cycle", sgdr_cycle_count, epoch + 1)

        train_summary_writer.add_scalar("loss", train_loss_avg, epoch + 1)
        train_summary_writer.add_scalar("mapk", train_mapk_avg, epoch + 1)
        val_summary_writer.add_scalar("loss", val_loss_avg, epoch + 1)
        val_summary_writer.add_scalar("mapk", val_mapk_avg, epoch + 1)

        epoch_end_time = time.time()
        epoch_duration_time = epoch_end_time - epoch_start_time

        print(
            "[%03d/%03d] %ds, lr: %.6f, loss: %.4f, val_loss: %.4f, acc: %.4f, val_acc: %.4f, ckpt: %d, rst: %d" % (
                epoch + 1,
                epochs_to_train,
                epoch_duration_time,
                get_learning_rate(optimizer),
                train_loss_avg,
                val_loss_avg,
                train_mapk_avg,
                val_mapk_avg,
                int(ckpt_saved),
                int(sgdr_reset)))

        print('{"chart": "best_val_mapk", "x": %d, "y": %.4f}' % (epoch + 1, global_val_mapk_best_avg))
        print('{"chart": "val_loss", "x": %d, "y": %.4f}' % (epoch + 1, val_loss_avg))
        print('{"chart": "val_mapk", "x": %d, "y": %.4f}' % (epoch + 1, val_mapk_avg))
        print('{"chart": "val_accuracy@1", "x": %d, "y": %.4f}' % (epoch + 1, val_accuracy_top1_avg))
        print('{"chart": "val_accuracy@3", "x": %d, "y": %.4f}' % (epoch + 1, val_accuracy_top3_avg))
        print('{"chart": "val_accuracy@5", "x": %d, "y": %.4f}' % (epoch + 1, val_accuracy_top5_avg))
        print('{"chart": "val_accuracy@10", "x": %d, "y": %.4f}' % (epoch + 1, val_accuracy_top10_avg))
        print('{"chart": "sgdr_cycle", "x": %d, "y": %d}' % (epoch + 1, sgdr_cycle_count))
        print('{"chart": "loss", "x": %d, "y": %.4f}' % (epoch + 1, train_loss_avg))
        print('{"chart": "mapk", "x": %d, "y": %.4f}' % (epoch + 1, train_mapk_avg))
        print('{"chart": "lr_scaled", "x": %d, "y": %.4f}' % (epoch + 1, 1000 * get_learning_rate(optimizer)))
        print('{"chart": "mem_used", "x": %d, "y": %.2f}' % (epoch + 1, psutil.virtual_memory().used / 2 ** 30))
        print('{"chart": "epoch_time", "x": %d, "y": %d}' % (epoch + 1, epoch_duration_time))

        sys.stdout.flush()

        if (sgdr_reset or lr_scheduler_type == "reduce_on_plateau") and epoch - epoch_of_last_improval >= patience:
            print("early abort due to lack of improval", flush=True)
            break

        if max_sgdr_cycles is not None and sgdr_cycle_count >= max_sgdr_cycles:
            print("early abort due to maximum number of sgdr cycles reached", flush=True)
            break

    optim_summary_writer.close()
    train_summary_writer.close()
    val_summary_writer.close()

    train_end_time = time.time()
    print()
    print("Train time: %s" % str(datetime.timedelta(seconds=train_end_time - train_start_time)), flush=True)

    # Disabled: average all per-cycle snapshots into an SWA model.
    if False:
        swa_model = create_model(type=model_type, input_size=image_size, num_classes=len(train_data.categories)).to(
            device)
        swa_update_count = 0
        for f in find_sorted_model_files(output_dir):
            print("merging model '{}' into swa model".format(f), flush=True)
            m = create_model(type=model_type, input_size=image_size, num_classes=len(train_data.categories)).to(device)
            m.load_state_dict(torch.load(f, map_location=device))
            swa_update_count += 1
            moving_average(swa_model, m, 1.0 / swa_update_count)
            # bn_update(train_set_data_loader, swa_model)
        torch.save(swa_model.state_dict(), "{}/swa_model.pth".format(output_dir))

    test_data = TestData(input_dir)
    test_set = TestDataset(test_data.df, image_size, use_extended_stroke_channels)
    test_set_data_loader = \
        DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

    # Predict with the best overall checkpoint, wrapped as a one-model ensemble.
    model.load_state_dict(torch.load("{}/model.pth".format(output_dir), map_location=device))
    model = Ensemble([model])

    categories = train_data.categories

    submission_df = test_data.df.copy()
    predictions, predicted_words = predict(model, test_set_data_loader, categories, tta=False)
    submission_df["word"] = predicted_words
    np.save("{}/submission_predictions.npy".format(output_dir), np.array(predictions))
    submission_df.to_csv("{}/submission.csv".format(output_dir), columns=["word"])

    submission_df = test_data.df.copy()
    predictions, predicted_words = predict(model, test_set_data_loader, categories, tta=True)
    submission_df["word"] = predicted_words
    np.save("{}/submission_predictions_tta.npy".format(output_dir), np.array(predictions))
    submission_df.to_csv("{}/submission_tta.csv".format(output_dir), columns=["word"])

    val_set_data_loader = \
        DataLoader(val_set, batch_size=64, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

    model = load_ensemble_model(output_dir, 3, val_set_data_loader, criterion, model_type, image_size, len(categories))
    submission_df = test_data.df.copy()
    predictions, predicted_words = predict(model, test_set_data_loader, categories, tta=True)
    submission_df["word"] = predicted_words
    np.save("{}/submission_predictions_ensemble_tta.npy".format(output_dir), np.array(predictions))
    submission_df.to_csv("{}/submission_ensemble_tta.csv".format(output_dir), columns=["word"])

    # Diagonal of the validation confusion matrix, one value per category.
    confusion, _ = calculate_confusion(model, val_set_data_loader, len(categories))
    precisions = np.array([confusion[c, c] for c in range(confusion.shape[0])])
    # NOTE(review): 10 evenly spaced q values give 0, 11.1, ..., 100 — not deciles.
    percentiles = np.percentile(precisions, q=np.linspace(0, 100, 10))

    print()
    print("Category precision percentiles:")
    print(percentiles)

    print()
    print("Categories sorted by precision:")
    print(np.array(categories)[np.argsort(precisions)])