def train(
    model: TEDD1104,
    optimizer_name: str,
    optimizer: torch.optim.Optimizer,
    train_dir: str,
    dev_dir: str,
    test_dir: str,
    output_dir: str,
    batch_size: int,
    initial_epoch: int,
    num_epoch: int,
    max_acc: float,
    hide_map_prob: float,
    num_load_files_training: int,
    fp16: bool = True,
    amp_opt_level=None,
    save_checkpoints: bool = True,
    save_every: int = 100,
    save_best: bool = True,
):
    """
    Train a model

    Input:
    - model: TEDD1104 model to train
    - optimizer_name: Name of the optimizer to use [SGD, Adam]
    - optimizer: Optimizer (torch.optim)
    - train_dir: Directory where the train files are stored
    - dev_dir: Directory where the development files are stored
    - test_dir: Directory where the test files are stored
    - output_dir: Directory where the model and the checkpoints are going to be saved
    - batch_size: Batch size (Around 10 for 8GB GPU)
    - initial_epoch: Number of previous epochs used to train the model (0 unless the model has been
      restored from checkpoint)
    - num_epoch: Number of epochs to train for
    - max_acc: Best accuracy achieved so far in the development set (0 unless the model has been
      restored from checkpoint)
    - hide_map_prob: Probability of removing the minimap (black square) from the image (0<=hide_map_prob<=1)
    - num_load_files_training: Number of training files to load into memory at a time
    - fp16: Use FP16 for training
    - amp_opt_level: If FP16 training, the Nvidia apex opt level
    - save_checkpoints: save a checkpoint every save_every iterations (each checkpoint overwrites the previous one)
    - save_every: Number of iterations between checkpoints
    - save_best: save the model that achieves the highest accuracy in the development set

    Output:
     - float: Accuracy in the development set of the best model
    """

    if fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
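
    # Note: when fp16 is enabled, the model and the optimizer are assumed to have already been
    # prepared with apex amp.initialize(...) before train() is called; only amp.scale_loss and
    # amp.master_params are used below.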

    criterion: CrossEntropyLoss = torch.nn.CrossEntropyLoss()
    print("Loading dev set")
    X_dev, y_dev = load_dataset(dev_dir, fp=16 if fp16 else 32)
    X_dev = torch.from_numpy(X_dev)
    print("Loading test set")
    X_test, y_test = load_dataset(test_dir, fp=16 if fp16 else 32)
    X_test = torch.from_numpy(X_test)

    acc_dev: float = 0.0
    total_training_examples: int = 0
    printTrace("Training...")
    for epoch in range(num_epoch):
        iteration_no = 0
        num_used_files: int = 0
        files: List[str] = glob.glob(os.path.join(train_dir, "*.npz"))
        random.shuffle(files)
        # Get files in batches, all files will be loaded and data will be shuffled
        for paths in batch(files, num_load_files_training):
            iteration_no += 1
            num_used_files += num_load_files_training
            model.train()
            start_time: float = time.time()

            X, y = load_and_shuffle_datasets(paths=paths,
                                             fp=16 if fp16 else 32,
                                             hide_map_prob=hide_map_prob)
            total_training_examples += len(y)
            running_loss = 0.0
            num_batches = 0

            for X_batch, y_batch in nn_batchs(X, y, batch_size):
                X_batch, y_batch = (
                    torch.from_numpy(X_batch).to(device),
                    torch.from_numpy(y_batch).long().to(device),
                )
                optimizer.zero_grad()
                outputs = model.forward(X_batch)
                loss = criterion(outputs, y_batch)
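                # With apex amp, the loss is scaled before backward() to avoid FP16 underflow and
                # the gradients are clipped on the FP32 master parameters kept by the optimizer.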
                if fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                if fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), 1.0)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                optimizer.step()
                running_loss += loss.item()
                num_batches += 1
            start_time_eval: float = time.time()
            # Print Statistics
            if len(X) > 0 and len(y) > 0:
                acc_train = evaluate(
                    model=model,
                    X=torch.from_numpy(X),
                    golds=y,
                    device=device,
                    batch_size=batch_size,
                )
            else:
                acc_train = -1.0

            acc_dev = evaluate(
                model=model,
                X=X_dev,
                golds=y_dev,
                device=device,
                batch_size=batch_size,
            )

            acc_test = evaluate(
                model=model,
                X=X_test,
                golds=y_test,
                device=device,
                batch_size=batch_size,
            )

            printTrace(
                f"EPOCH: {initial_epoch+epoch}. Iteration {iteration_no}. "
                f"{num_used_files} of {len(files)} files. "
                f"Total examples used for training {total_training_examples}. "
                f"Iteration time: {time.time() - start_time} secs. Eval time: {time.time() - start_time_eval} secs."
            )

            printTrace(
                f"Loss: {-1 if num_batches == 0 else running_loss / num_batches}. Acc training set: {acc_train}. "
                f"Acc dev set: {acc_dev}. Acc test set: {acc_test}")

            if acc_dev > max_acc and save_best:
                max_acc = acc_dev
                printTrace(
                    f"New max acc in dev set {max_acc}. Saving model...")
                save_model(
                    model=model,
                    save_dir=output_dir,
                    fp16=fp16,
                    amp_opt_level=amp_opt_level,
                )

            if save_checkpoints and iteration_no % save_every == 0:
                printTrace("Saving checkpoint...")
                save_checkpoint(
                    path=os.path.join(output_dir, "checkpoint.pt"),
                    model=model,
                    optimizer_name=optimizer_name,
                    optimizer=optimizer,
                    acc_dev=acc_dev,
                    epoch=initial_epoch + epoch,
                    fp16=fp16,
                    opt_level=amp_opt_level,
                )

    return max_acc
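
# Note: batch(), nn_batchs(), load_dataset(), load_and_shuffle_datasets(), evaluate(),
# save_model() and save_checkpoint() are project helpers that are not shown in these examples.
# A minimal sketch of batch(), assumed from how it is called above (chunking the list of .npz
# file paths into groups of num_load_files_training paths), not the project's exact code:
def batch(iterable: List[str], n: int):
    """Yield successive chunks of n items from iterable (assumed behaviour)."""
    for i in range(0, len(iterable), n):
        yield iterable[i : i + n]
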
def train(
    model: TEDD1104,
    optimizer_name: str,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.ReduceLROnPlateau,
    train_dir: str,
    dev_dir: str,
    test_dir: str,
    output_dir: str,
    batch_size: int,
    accumulation_steps: int,
    initial_epoch: int,
    num_epoch: int,
    max_acc: float,
    hide_map_prob: float,
    dropout_images_prob: List[float],
    num_load_files_training: int,
    fp16: bool = True,
    amp_opt_level=None,
    save_checkpoints: bool = True,
    eval_every: int = 5,
    save_every: int = 20,
    save_best: bool = True,
):

    """
    Train a model

    Input:
    - model: TEDD1104 model to train
    - optimizer_name: Name of the optimizer to use [SGD, Adam]
    - optimizer: Optimizer (torch.optim)
    - scheduler: Learning rate scheduler (torch.optim.lr_scheduler)
    - train_dir: Directory where the train files are stored
    - dev_dir: Directory where the development files are stored
    - test_dir: Directory where the test files are stored
    - output_dir: Directory where the model and the checkpoints are going to be saved
    - batch_size: Batch size (Around 10 for 8GB GPU)
    - accumulation_steps: Number of gradient accumulation steps (the effective batch size is
      batch_size * accumulation_steps)
    - initial_epoch: Number of previous epochs used to train the model (0 unless the model has been
      restored from checkpoint)
    - num_epoch: Number of epochs to train for
    - max_acc: Best accuracy achieved so far in the development set (0 unless the model has been
      restored from checkpoint)
    - hide_map_prob: Probability of removing the minimap (putting a black square over it)
       from a training example (0<=hide_map_prob<=1)
    - dropout_images_prob: List of 5 floats or None. Probability of removing each input image
      (black image) from a training example during training (0<=dropout_images_prob<=1)
    - num_load_files_training: Number of training files to load into memory at a time
    - fp16: Use FP16 for training
    - amp_opt_level: If FP16 training, the Nvidia apex opt level
    - save_checkpoints: save a checkpoint periodically (each checkpoint overwrites the previous one)
    - eval_every: Evaluate on the dev and test sets every eval_every epochs
    - save_every: Save a checkpoint every save_every epochs
    - save_best: save the model that achieves the highest accuracy in the development set

    Output:
     - float: Accuracy in the development set of the best model
    """
    writer: SummaryWriter = SummaryWriter()

    if fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )

    criterion: CrossEntropyLoss = torch.nn.CrossEntropyLoss()
    print("Loading dev set")
    X_dev, y_dev = load_dataset(dev_dir, fp=16 if fp16 else 32)
    X_dev = torch.from_numpy(X_dev)
    print("Loading test set")
    X_test, y_test = load_dataset(test_dir, fp=16 if fp16 else 32)
    X_test = torch.from_numpy(X_test)
    total_training_examples: int = 0
    model.zero_grad()


    trainLoader = DataLoader(
        dataset=PickleDataset(train_dir),
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
    )

    printTrace("Training...")
    iteration_no: int = 0
    for epoch in range(num_epoch):
        #step_no: int = 0
        #num_used_files: int = 0

        print(f"EpochNum: {epoch}")
        model.train()
        start_time: float = time.time()
        running_loss: float = 0.0
        acc_dev: float = 0.0
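        # Each item from trainLoader packs the 5 input frames of every example; they are reshaped
        # below so the frames are folded into the batch dimension: (batch_size * 5, 3, H, W).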

        for num_batches, inputs in enumerate(trainLoader, start=1):
            X_batch = torch.reshape(
                inputs[0],
                (inputs[0].shape[0] * 5, 3, inputs[0].shape[2], inputs[0].shape[3]),
            ).to(device)
            y_batch = torch.reshape(inputs[1], (inputs[0].shape[0],)).long().to(device)

            outputs = model.forward(X_batch)
            loss = criterion(outputs, y_batch) / accumulation_steps
            running_loss += loss.item()

            if fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1.0)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            # Step only every accumulation_steps mini-batches (and on the last one),
            # matching the loss scaling by accumulation_steps above.
            if num_batches % accumulation_steps == 0 or num_batches == len(trainLoader):
                optimizer.step()
                model.zero_grad()

        #scheduler.step(running_loss)

        # Print Statistics
        printTrace(
            f"Loss: {running_loss / num_batches}. "
            f"Learning rate {optimizer.state_dict()['param_groups'][0]['lr']}"
        )

        writer.add_scalar("Loss/train", running_loss / num_batches, iteration_no)

        if (iteration_no + 1) % eval_every == 0:
            start_time_eval: float = time.time()

            acc_dev: float = evaluate(
                model=model,
                X=X_dev,
                golds=y_dev,
                device=device,
                batch_size=batch_size,
            )

            acc_test: float = evaluate(
                model=model,
                X=X_test,
                golds=y_test,
                device=device,
                batch_size=batch_size,
            )

            printTrace(
                f"Acc dev set: {round(acc_dev,2)}. "
                f"Acc test set: {round(acc_test,2)}.  "
                f"Eval time: {round(time.time() - start_time_eval,2)} secs."
            )

            if 0.0 < acc_dev > max_acc and save_best:
                max_acc = acc_dev
                printTrace(
                    f"New max acc in dev set {round(max_acc,2)}. Saving model..."
                )
                save_model(
                    model=model,
                    save_dir=output_dir,
                    fp16=fp16,
                    amp_opt_level=amp_opt_level,
                )
            writer.add_scalar("Accuracy/dev", acc_dev, iteration_no)
            writer.add_scalar("Accuracy/test", acc_test, iteration_no)

        if save_checkpoints and (iteration_no + 1) % save_every == 0:
            printTrace("Saving checkpoint...")
            save_checkpoint(
                path=os.path.join(output_dir, "checkpoint.pt"),
                model=model,
                optimizer_name=optimizer_name,
                optimizer=optimizer,
                scheduler=scheduler,
                acc_dev=acc_dev,
                epoch=initial_epoch + epoch,
                fp16=fp16,
                opt_level=amp_opt_level,
            )

        iteration_no += 1

    return max_acc
def train(
    model: TEDD1104,
    optimizer_name: str,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.ReduceLROnPlateau,
    train_dir: str,
    dev_dir: str,
    test_dir: str,
    output_dir: str,
    batch_size: int,
    accumulation_steps: int,
    initial_epoch: int,
    num_epoch: int,
    max_acc: float,
    hide_map_prob: float,
    dropout_images_prob: List[float],
    num_load_files_training: int,
    fp16: bool = True,
    amp_opt_level=None,
    save_checkpoints: bool = True,
    eval_every: int = 5,
    save_every: int = 20,
    save_best: bool = True,
):
    """
    Train a model

    Input:
    - model: TEDD1104 model to train
    - optimizer_name: Name of the optimizer to use [SGD, Adam]
    - optimizer: Optimizer (torch.optim)
    - scheduler: Learning rate scheduler (torch.optim.lr_scheduler)
    - train_dir: Directory where the train files are stored
    - dev_dir: Directory where the development files are stored
    - test_dir: Directory where the test files are stored
    - output_dir: Directory where the model and the checkpoints are going to be saved
    - batch_size: Batch size (Around 10 for 8GB GPU)
    - accumulation_steps: Number of gradient accumulation steps (the effective batch size is
      batch_size * accumulation_steps)
    - initial_epoch: Number of previous epochs used to train the model (0 unless the model has been
      restored from checkpoint)
    - num_epoch: Number of epochs to train for
    - max_acc: Best accuracy achieved so far in the development set (0 unless the model has been
      restored from checkpoint)
    - hide_map_prob: Probability of removing the minimap (putting a black square over it)
       from a training example (0<=hide_map_prob<=1)
    - dropout_images_prob: List of 5 floats or None. Probability of removing each input image
      (black image) from a training example during training (0<=dropout_images_prob<=1)
    - num_load_files_training: Number of training files to load into memory at a time
    - fp16: Use FP16 for training
    - amp_opt_level: If FP16 training, the Nvidia apex opt level
    - save_checkpoints: save a checkpoint periodically (each checkpoint overwrites the previous one)
    - eval_every: Evaluate on the dev and test sets every eval_every iterations
    - save_every: Save a checkpoint every save_every iterations
    - save_best: save the model that achieves the highest accuracy in the development set

    Output:
     - float: Accuracy in the development set of the best model
    """
    writer: SummaryWriter = SummaryWriter()

    if fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )

    criterion: CrossEntropyLoss = torch.nn.CrossEntropyLoss()
    print("Loading dev set")
    X_dev, y_dev = load_dataset(dev_dir, fp=16 if fp16 else 32)
    X_dev = torch.from_numpy(X_dev)
    print("Loading test set")
    X_test, y_test = load_dataset(test_dir, fp=16 if fp16 else 32)
    X_test = torch.from_numpy(X_test)
    total_training_examples: int = 0
    model.zero_grad()

    printTrace("Training...")
    for epoch in range(num_epoch):
        step_no: int = 0
        iteration_no: int = 0
        num_used_files: int = 0
        data_loader = DataLoaderTEDD(
            dataset_dir=train_dir,
            nfiles2load=num_load_files_training,
            hide_map_prob=hide_map_prob,
            dropout_images_prob=dropout_images_prob,
            fp=16 if fp16 else 32,
        )
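        # DataLoaderTEDD is a project helper (not shown here); from its usage below it is assumed
        # to return the next block of nfiles2load files as an (X, y) numpy pair from get_next(),
        # and a falsy value once all the files of the epoch have been consumed.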

        data = data_loader.get_next()
        # Get files in batches, all files will be loaded and data will be shuffled
        while data:
            X, y = data
            model.train()
            start_time: float = time.time()
            total_training_examples += len(y)
            running_loss: float = 0.0
            num_batches: int = 0
            acc_dev: float = 0.0

            for X_batch, y_batch in nn_batchs(X, y, batch_size):
                X_batch, y_batch = (
                    torch.from_numpy(X_batch).to(device),
                    torch.from_numpy(y_batch).long().to(device),
                )

                outputs = model.forward(X_batch)
                loss = criterion(outputs, y_batch) / accumulation_steps
                running_loss += loss.item()

                if fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                if fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), 1.0)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                if (step_no + 1) % accumulation_steps == 0 or (
                        num_used_files + 1 >
                        len(data_loader) - num_load_files_training
                        and num_batches == math.ceil(len(y) / batch_size) - 1
                ):  # If we are in the last batch of the epoch we also want to perform gradient descent
                    optimizer.step()
                    model.zero_grad()

                num_batches += 1
                step_no += 1

            num_used_files += num_load_files_training

            # Print Statistics
            printTrace(
                f"EPOCH: {initial_epoch+epoch}. Iteration {iteration_no}. "
                f"{num_used_files} of {len(data_loader)} files. "
                f"Total examples used for training {total_training_examples}. "
                f"Iteration time: {round(time.time() - start_time,2)} secs.")
            printTrace(
                f"Loss: {-1 if num_batches == 0 else running_loss / num_batches}. "
                f"Learning rate {optimizer.state_dict()['param_groups'][0]['lr']}"
            )
            writer.add_scalar("Loss/train", running_loss / num_batches, iteration_no)

            scheduler.step(running_loss / num_batches)

            if (iteration_no + 1) % eval_every == 0:
                start_time_eval: float = time.time()
                if len(X) > 0 and len(y) > 0:
                    acc_train: float = evaluate(
                        model=model,
                        X=torch.from_numpy(X),
                        golds=y,
                        device=device,
                        batch_size=batch_size,
                    )
                else:
                    acc_train = -1.0

                acc_dev: float = evaluate(
                    model=model,
                    X=X_dev,
                    golds=y_dev,
                    device=device,
                    batch_size=batch_size,
                )

                acc_test: float = evaluate(
                    model=model,
                    X=X_test,
                    golds=y_test,
                    device=device,
                    batch_size=batch_size,
                )

                printTrace(
                    f"Acc training set: {round(acc_train,2)}. "
                    f"Acc dev set: {round(acc_dev,2)}. "
                    f"Acc test set: {round(acc_test,2)}.  "
                    f"Eval time: {round(time.time() - start_time_eval,2)} secs."
                )

                if 0.0 < acc_dev > max_acc and save_best:
                    max_acc = acc_dev
                    printTrace(
                        f"New max acc in dev set {round(max_acc,2)}. Saving model..."
                    )
                    save_model(
                        model=model,
                        save_dir=output_dir,
                        fp16=fp16,
                        amp_opt_level=amp_opt_level,
                    )
                if acc_train > -1:
                    writer.add_scalar("Accuracy/train", acc_train,
                                      iteration_no)
                writer.add_scalar("Accuracy/dev", acc_dev, iteration_no)
                writer.add_scalar("Accuracy/test", acc_test, iteration_no)

            if save_checkpoints and (iteration_no + 1) % save_every == 0:
                printTrace("Saving checkpoint...")
                save_checkpoint(
                    path=os.path.join(output_dir, "checkpoint.pt"),
                    model=model,
                    optimizer_name=optimizer_name,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    acc_dev=acc_dev,
                    epoch=initial_epoch + epoch,
                    fp16=fp16,
                    opt_level=amp_opt_level,
                )

            iteration_no += 1
            data = data_loader.get_next()

        data_loader.close()

    return max_acc
def train(
    model: TEDD1104,
    optimizer_name: str,
    optimizer: torch.optim.Optimizer,
    scheduler: torch.optim.lr_scheduler.ReduceLROnPlateau,
    scaler: GradScaler,
    train_dir: str,
    dev_dir: str,
    test_dir: str,
    output_dir: str,
    batch_size: int,
    accumulation_steps: int,
    initial_epoch: int,
    num_epoch: int,
    running_loss: float,
    total_batches: int,
    total_training_examples: int,
    max_acc: float,
    hide_map_prob: float,
    dropout_images_prob: List[float],
    fp16: bool = True,
    save_checkpoints: bool = True,
    save_every: int = 20,
    save_best: bool = True,
):
    """
    Train a model

    Input:
    - model: TEDD1104 model to train
    - optimizer_name: Name of the optimizer to use [SGD, Adam]
    - optimizer: Optimizer (torch.optim)
    - scheduler: Learning rate scheduler (torch.optim.lr_scheduler)
    - scaler: GradScaler (torch.cuda.amp) used when fp16 is True
    - train_dir: Directory where the train files are stored
    - dev_dir: Directory where the development files are stored
    - test_dir: Directory where the test files are stored
    - output_dir: Directory where the model and the checkpoints are going to be saved
    - batch_size: Batch size (Around 10 for 8GB GPU)
    - accumulation_steps: Number of gradient accumulation steps (the effective batch size is
      batch_size * accumulation_steps)
    - initial_epoch: Number of previous epochs used to train the model (0 unless the model has been
      restored from checkpoint)
    - num_epoch: Number of epochs to train for
    - running_loss: Accumulated training loss (0 unless the model has been restored from checkpoint)
    - total_batches: Number of effective batches used so far (0 unless the model has been
      restored from checkpoint)
    - total_training_examples: Number of training examples used so far (0 unless the model has been
      restored from checkpoint)
    - max_acc: Best accuracy achieved so far in the development set (0 unless the model has been
      restored from checkpoint)
    - hide_map_prob: Probability of removing the minimap (putting a black square over it)
       from a training example (0<=hide_map_prob<=1)
    - dropout_images_prob: List of 5 floats or None. Probability of removing each input image
      (black image) from a training example during training (0<=dropout_images_prob<=1)
    - fp16: Use FP16 for training (torch.cuda.amp)
    - save_checkpoints: save a checkpoint periodically (each checkpoint overwrites the previous one)
    - save_every: Save a checkpoint every save_every effective batches
    - save_best: save the model that achieves the highest accuracy in the development set

    Output:
     - float: Accuracy in the development set of the best model
    """

    if not os.path.exists(output_dir):
        print(f"{output_dir} does not exist. We will create it.")
        os.makedirs(output_dir)

    writer: SummaryWriter = SummaryWriter()

    criterion: CrossEntropyLoss = torch.nn.CrossEntropyLoss().to(device)
    model.zero_grad()
    print_message("Training...")
    for epoch in range(num_epoch):
        acc_dev: float = 0.0
        num_batches: int = 0
        step_no: int = 0

        data_loader_train = DataLoader(
            Tedd1104Dataset(
                dataset_dir=train_dir,
                hide_map_prob=hide_map_prob,
                dropout_images_prob=dropout_images_prob,
            ),
            batch_size=batch_size,
            shuffle=True,
            num_workers=os.cpu_count(),
            pin_memory=True,
        )
        start_time: float = time.time()
        step_start_time: float = time.time()
        dataloader_delay: float = 0
        model.train()
        for batch in data_loader_train:
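            # Stack the 5 frames of each example along a new dimension and fold that dimension
            # into the batch axis: (B, 5, C, H, W) -> (B*5, C, H, W).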

            x = torch.flatten(
                torch.stack(
                    (
                        batch["image1"],
                        batch["image2"],
                        batch["image3"],
                        batch["image4"],
                        batch["image5"],
                    ),
                    dim=1,
                ),
                start_dim=0,
                end_dim=1,
            ).to(device)

            y = batch["y"].to(device)
            dataloader_delay += time.time() - step_start_time

            total_training_examples += len(y)

            if fp16:
                with autocast():
                    outputs = model.forward(x)
                    loss = criterion(outputs, y)
                    loss = loss / accumulation_steps

                running_loss += loss.item()
                scaler.scale(loss).backward()

            else:
                outputs = model.forward(x)
                loss = criterion(outputs, y) / accumulation_steps
                running_loss += loss.item()
                loss.backward()

            if ((step_no + 1) % accumulation_steps == 0) or (
                    step_no + 1 >= len(data_loader_train)
            ):  # If we are in the last batch of the epoch we also want to perform gradient descent
                if fp16:
                    # Gradient clipping
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
                else:
                    # Gradient clipping
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                    optimizer.step()
                    optimizer.zero_grad()

                total_batches += 1
                num_batches += 1
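                # scheduler.step() is fed the smoothed training loss after every effective batch,
                # which assumes a ReduceLROnPlateau-style scheduler (its step() takes a metric).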
                scheduler.step(running_loss / total_batches)

                batch_time = round(time.time() - start_time, 2)
                est: float = batch_time * (math.ceil(
                    len(data_loader_train) / accumulation_steps) - num_batches)
                print_message(
                    f"EPOCH: {initial_epoch + epoch}. "
                    f"{num_batches} of {math.ceil(len(data_loader_train)/accumulation_steps)} batches. "
                    f"Total examples used for training {total_training_examples}. "
                    f"Iteration time: {batch_time} secs. "
                    f"Data Loading bottleneck: {round(dataloader_delay, 2)} secs. "
                    f"Epoch estimated time: "
                    f"{str(datetime.timedelta(seconds=est)).split('.')[0]}")

                print_message(
                    f"Loss: {running_loss / total_batches}. "
                    f"Learning rate {optimizer.state_dict()['param_groups'][0]['lr']}"
                )

                writer.add_scalar("Loss/train", running_loss / total_batches,
                                  total_batches)

                if save_checkpoints and (total_batches + 1) % save_every == 0:
                    print_message("Saving checkpoint...")
                    save_checkpoint(
                        path=os.path.join(output_dir, "checkpoint.pt"),
                        model=model,
                        optimizer_name=optimizer_name,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        running_loss=running_loss,
                        total_batches=total_batches,
                        total_training_examples=total_training_examples,
                        acc_dev=max_acc,
                        epoch=initial_epoch + epoch,
                        fp16=fp16,
                        scaler=None if not fp16 else scaler,
                    )

                dataloader_delay: float = 0
                start_time: float = time.time()

            step_no += 1
            step_start_time = time.time()

        del data_loader_train

        print_message("Dev set evaluation...")

        start_time_eval: float = time.time()

        data_loader_dev = DataLoader(
            Tedd1104Dataset(
                dataset_dir=dev_dir,
                hide_map_prob=0,
                dropout_images_prob=[0, 0, 0, 0, 0],
            ),
            batch_size=batch_size // 2,  # Use a smaller batch size to prevent OOM issues
            shuffle=False,
            num_workers=os.cpu_count() // 2,  # Use fewer cores to save RAM
            pin_memory=True,
        )

        acc_dev: float = evaluate(
            model=model,
            data_loader=data_loader_dev,
            device=device,
            fp16=fp16,
        )

        del data_loader_dev

        print_message("Test set evaluation...")
        data_loader_test = DataLoader(
            Tedd1104Dataset(
                dataset_dir=test_dir,
                hide_map_prob=0,
                dropout_images_prob=[0, 0, 0, 0, 0],
            ),
            batch_size=batch_size // 2,  # Use a smaller batch size to prevent OOM issues
            shuffle=False,
            num_workers=os.cpu_count() // 2,  # Use fewer cores to save RAM
            pin_memory=True,
        )

        acc_test: float = evaluate(
            model=model,
            data_loader=data_loader_test,
            device=device,
            fp16=fp16,
        )

        del data_loader_test

        print_message(
            f"Acc dev set: {round(acc_dev*100,2)}. "
            f"Acc test set: {round(acc_test*100,2)}.  "
            f"Eval time: {round(time.time() - start_time_eval,2)} secs.")

        if 0.0 < acc_dev > max_acc and save_best:
            max_acc = acc_dev
            print_message(
                f"New max acc in dev set {round(max_acc, 2)}. Saving model...")
            save_model(
                model=model,
                save_dir=output_dir,
                fp16=fp16,
            )

        writer.add_scalar("Accuracy/dev", acc_dev, epoch)
        writer.add_scalar("Accuracy/test", acc_test, epoch)

    return max_acc
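
# A minimal usage sketch for the last train() variant above. The TEDD1104 constructor arguments
# and all concrete hyper-parameter values below are assumptions for illustration only:
if __name__ == "__main__":
    model = TEDD1104(...).to(device)  # hypothetical: constructor arguments omitted
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2)
    scaler = GradScaler()
    max_acc = train(
        model=model,
        optimizer_name="SGD",
        optimizer=optimizer,
        scheduler=scheduler,
        scaler=scaler,
        train_dir="data/train",
        dev_dir="data/dev",
        test_dir="data/test",
        output_dir="models/tedd1104",
        batch_size=10,
        accumulation_steps=4,
        initial_epoch=0,
        num_epoch=5,
        running_loss=0.0,
        total_batches=0,
        total_training_examples=0,
        max_acc=0.0,
        hide_map_prob=0.0,
        dropout_images_prob=[0.0, 0.0, 0.0, 0.0, 0.0],
        fp16=True,
    )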