def main():
    feature_train_file = cfg.feature_train_file
    feature_test_file = cfg.feature_test_file
    train_dataset = PlaceDateset(feature_train_file)
    test_dataset = PlaceDateset(feature_test_file)

    train_loader = DataLoader(train_dataset,
                              batch_size=cfg.batch_size,
                              shuffle=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=cfg.batch_size,
                             shuffle=False)
    model = ViT(cfg=cfg,
                feature_seq=16,
                num_classes=1,
                dim=2048,
                depth=8,
                heads=8,
                mlp_dim=1024,
                dropout=0.1,
                emb_dropout=0.1).cuda()
    #model=ViT_cat(cfg=cfg,feature_seq=16,num_classes=2,dim=4096,depth=8,heads=8,mlp_dim=1024,dropout = 0.1,emb_dropout = 0.1).cuda()
    optimizer = optim.__dict__[cfg.optim.name](model.parameters(),
                                               **cfg.optim.setting)
    # Decay the learning rate at the specified epochs
    scheduler = optim.lr_scheduler.__dict__[cfg.stepper.name](
        optimizer, **cfg.stepper.setting)

    #criterion1 = nn.CrossEntropyLoss(torch.Tensor(cfg.loss.weight).cuda())
    #criterion1 = nn.BCEWithLogitsLoss()
    criterion1 = FocalLoss(logits=True)

    # Add the triplet loss term (metric learning with cosine-similarity distance)
    distance = CosineSimilarity()
    criterion2 = losses.TripletMarginLoss(distance=distance)

    total_loss = list()
    total_epoch = list()
    total_ap = list()
    total_acc = list()
    max_ap = 0

    for epoch in range(0, cfg.epoch):
        train(cfg, model, train_loader, optimizer, scheduler, epoch,
              criterion1, criterion2)
        loss, ap, acc = test(cfg, model, test_loader, criterion1, criterion2)
        total_loss.append(loss)
        total_ap.append(ap)
        total_epoch.append(epoch)
        total_acc.append(acc)
        print('Test Epoch: {} \tloss: {:.6f}\tap: {:.6f}\tacc: {:.6f}'.format(
            epoch, loss, ap, acc))
        if ap > max_ap:
            max_ap = ap
            best_model = copy.deepcopy(model)  # keep a frozen copy of the best weights (needs `import copy`)
    save_path = cfg.store + '.pth'
    torch.save(best_model.state_dict(), save_path)

    plt.figure()
    plt.plot(total_epoch, total_loss, 'b-', label=u'loss')
    plt.legend()
    loss_path = cfg.store + "_loss.png"
    plt.savefig(loss_path)

    plt.figure()
    plt.plot(total_epoch, total_ap, 'b-', label=u'AP')
    plt.legend()
    AP_path = cfg.store + "_AP.png"
    plt.savefig(AP_path)

    plt.figure()
    plt.plot(total_epoch, total_acc, 'b-', label=u'acc')
    plt.legend()
    acc_path = cfg.store + "_acc.png"
    plt.savefig(acc_path)
def starting_train(train_dataset,
                   val_dataset,
                   model,
                   hyperparameters,
                   n_eval,
                   summary_path,
                   device='cpu'):
    """
    Trains and evaluates a model.

    Args:
        train_dataset:   PyTorch dataset containing training data.
        val_dataset:     PyTorch dataset containing validation data.
        model:           PyTorch model to be trained.
        hyperparameters: Dictionary containing hyperparameters.
        n_eval:          Interval at which we evaluate our model.
        summary_path:    Path where Tensorboard summaries are located.
    """
    writer = SummaryWriter()
    model = model.to(device)

    # Get keyword arguments
    batch_size, epochs = hyperparameters["batch_size"], hyperparameters[
        "epochs"]

    # Initialize dataloaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=True)

    # Initalize optimizer (for gradient descent) and loss function
    optimizer = optim.Adam(model.parameters())
    loss_fn = losses.TripletMarginLoss()
    miner = miners.BatchEasyHardMiner(pos_strategy='all', neg_strategy='hard')

    data = pd.read_csv("/content/train.csv")
    data = data.sample(frac=1, random_state=1)  # Shuffle data
    train_data = data.iloc[:int(0.9 * len(data))]
    test_data = data.iloc[int(0.9 * len(data)) + 1:]
    # train_data = data.iloc[0:200]
    # test_data = data.iloc[200:300]

    train_eval_dataset = EvaluationDataset(
        train_data,
        "/content/acmai-team4/corners.csv",
        "/content/train/",
        train=True,
        drop_duplicate_whales=True,
        # If you set this to True, your evaluation accuracy will be lower!!
        # If you set this to False, evaluate() will take longer!!
        # Recommendation: set this to True during training, and when you're done,
        # create a new dataset with drop_duplicate_whales=False to get a final
        # evaluation metric.
    )
    train_eval_dataset.to(device)

    test_eval_dataset = EvaluationDataset(
        test_data,
        "/content/acmai-team4/corners.csv",
        "/content/train/",
        train=False,
        drop_duplicate_whales=False,
    )

    test_eval_dataset.to(device)

    train_eval_loader = torch.utils.data.DataLoader(train_eval_dataset,
                                                    batch_size=64)
    test_eval_loader = torch.utils.data.DataLoader(test_eval_dataset,
                                                   batch_size=64)

    # Initialize summary writer (for logging)
    if summary_path is not None:
        writer = torch.utils.tensorboard.SummaryWriter(summary_path)

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1} of {epochs}")
        trainLosses = []
        model.train()
        # Loop over each batch in the dataset
        step = 1
        for i, batch in enumerate(train_loader):

            batch_inputs = batch[1][0]
            batch_labels = batch[0][0]
            # print("inputs")
            # print(len(batch_inputs))
            # print(batch_inputs)
            # print("labels")
            # print(batch_labels)
            # print(batch[0])
            for j in range(1, 4):  # set to batch_size
                batch_inputs = torch.cat((batch_inputs, batch[1][j]), 0)
                batch_labels = torch.cat((batch_labels, batch[0][j]), 0)

            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            # print(batch_labels)
            print(f"\rIteration {i + 1} of {len(train_loader)} ...", end="")

            #batch_inputs, batch_labels = batch_inputs.to(device), batch_labels.to(device)

            # main body of your training
            optimizer.zero_grad()
            embeddings = model(batch_inputs)  # images is a batch of images
            # print(embeddings)
            hard_triplets = miner(embeddings, batch_labels)
            #batch_outputs = model(batch_inputs)
            # print(f"batch size:\n{batch_outputs.shape()}\n\n")
            loss = loss_fn(embeddings, batch_labels, hard_triplets)
            loss.backward()
            trainLosses.append(loss.detach())  # detach so the autograd graph is not retained across the epoch
            optimizer.step()
            if step % n_eval == 0:
                # print(len(train_eval_loader))
                accuracy = evaluate(train_eval_loader, test_eval_loader, model)
                print(f"Accuracy: {accuracy}")
                writer.add_scalar(f"Accuracy", accuracy, epoch)

            step += 1
        print('End of epoch loss:',
              round((sum(trainLosses) / len(train_dataset)).item(), 3))

    print()
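
# Hedged usage sketch for starting_train above; the dataset/model objects and the values
# below are illustrative assumptions, not part of the original snippet.
#
#     hyperparameters = {"batch_size": 4, "epochs": 10}
#     starting_train(train_dataset, val_dataset, model, hyperparameters,
#                    n_eval=50, summary_path="runs/triplet", device="cuda")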
Example #3
def train(data, epochs, triplet_alpha, miner, alpha_pow, alpha_mul):
    device = torch.device("cpu")

    distance = distances.CosineSimilarity()
    if miner == 'batch-hard':
        mining_func = miners.BatchHardMiner(distance=distance)
    elif miner == 'hard':
        mining_func = miners.TripletMarginMiner(margin=triplet_alpha,
                                                distance=distance,
                                                type_of_triplets="hard")
    elif miner == 'semihard':
        mining_func = miners.TripletMarginMiner(margin=triplet_alpha,
                                                distance=distance,
                                                type_of_triplets="semihard")
    elif miner == 'all':
        mining_func = miners.TripletMarginMiner(margin=triplet_alpha,
                                                distance=distance,
                                                type_of_triplets="all")

    loss_func = losses.TripletMarginLoss(margin=triplet_alpha,
                                         distance=distance)

    n_templ = 18

    p_pow = torch.FloatTensor(n_templ, 1, 1).uniform_(1, 1)
    p_pow = p_pow.clone().detach().to(device).requires_grad_(True)

    p_mul = torch.FloatTensor(n_templ, 1, 1).uniform_(1, 1)
    p_mul = p_mul.clone().detach().to(device).requires_grad_(True)

    opt_mul = optim.SGD([p_mul], lr=alpha_mul)
    opt_pow = optim.SGD([p_pow], lr=alpha_pow)

    __tensors_mul = []
    __tensors_pow = []
    __loss = []

    name = "{}_{}_{}_{}".format(triplet_alpha, miner, alpha_pow, alpha_mul)
    os.mkdir('/home/y.kozhevnikov/rust/{}'.format(name))

    for epoch in range(epochs):
        epoch_losses = []
        for word in tqdm(data):
            opt_mul.zero_grad()
            opt_pow.zero_grad()
            golds, vectors = data[word]
            golds = torch.tensor(golds)
            vectors = vectors.to(device)
            vectors = unite(vectors, p_mul, p_pow)

            indices_tuple = mining_func(vectors, golds)

            loss = loss_func(vectors, golds, indices_tuple)

            epoch_losses.append(loss.item())
            loss.backward()

            opt_mul.step()
            opt_pow.step()

        print(p_mul.view(n_templ))
        print(p_pow.view(n_templ))
        __tensors_mul.append(p_mul.clone().detach())
        __tensors_pow.append(p_pow.clone().detach())
        epoch_loss = torch.mean(torch.tensor(epoch_losses))
        __loss.append(epoch_loss)
        print(epoch_loss)

    for n, t in enumerate(__tensors_mul):
        torch.save(t, "/home/y.kozhevnikov/rust/{}/mul_{}.pt".format(name, n))
    for n, t in enumerate(__tensors_pow):
        torch.save(t, "/home/y.kozhevnikov/rust/{}/pow_{}.pt".format(name, n))
    with open("/home/y.kozhevnikov/rust/{}/loss.txt".format(name), 'w') as f:
        for l in __loss:
            f.write("{}\n".format(l))

    return p_pow, p_mul
def main():
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Running on device: {}'.format(device))

    # Data transformations
    trans_train = transforms.Compose([
        transforms.RandomApply(transforms=[
            transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),
            # transforms.RandomPerspective(distortion_scale=0.6, p=1.0),
            transforms.RandomRotation(degrees=(0, 180)),
            transforms.RandomHorizontalFlip(),
        ]),
        np.float32,
        transforms.ToTensor(),
        fixed_image_standardization,
    ])

    trans_val = transforms.Compose([
        # transforms.CenterCrop(120),
        np.float32,
        transforms.ToTensor(),
        fixed_image_standardization,
    ])

    train_dataset = datasets.ImageFolder(os.path.join(data_dir,
                                                      "train_aligned"),
                                         transform=trans_train)
    val_dataset = datasets.ImageFolder(os.path.join(data_dir, "val_aligned"),
                                       transform=trans_val)

    # Prepare the model
    model = InceptionResnetV1(classify=False,
                              pretrained="vggface2",
                              dropout_prob=0.5).to(device)
    # model = torch.load(model_path).to(device)

    # for param in list(model.parameters())[:-8]:
    #     param.requires_grad = False

    trunk_optimizer = torch.optim.RMSprop(model.parameters(), lr=LR)

    # Set the loss function
    loss = losses.TripletMarginLoss(margin=margin_p)

    # Set the mining function
    # miner = miners.BatchEasyHardMiner(
    #     pos_strategy=miners.BatchEasyHardMiner.EASY,
    #     neg_strategy=miners.BatchEasyHardMiner.SEMIHARD)

    miner = miners.BatchHardMiner()
    # Set the dataloader sampler
    sampler = samplers.MPerClassSampler(
        train_dataset.targets, m=4, length_before_new_iter=len(train_dataset))

    # Package the above stuff into dictionaries.
    models = {"trunk": model}
    optimizers = {"trunk_optimizer": trunk_optimizer}
    loss_funcs = {"metric_loss": loss}
    mining_funcs = {"tuple_miner": miner}

    # Create the tester
    record_keeper, _, _ = logging_presets.get_record_keeper(
        "logs", "tensorboard")
    hooks = logging_presets.get_hook_container(record_keeper)

    dataset_dict = {"val": val_dataset, "train": train_dataset}
    model_folder = "training_saved_models"

    def visualizer_hook(umapper, umap_embeddings, labels, split_name, keyname,
                        *args):
        logging.info("UMAP plot for the {} split and label set {}".format(
            split_name, keyname))
        label_set = np.unique(labels)
        num_classes = len(label_set)
        fig = plt.figure(figsize=(8, 7))
        plt.gca().set_prop_cycle(
            cycler("color", [
                plt.cm.nipy_spectral(i)
                for i in np.linspace(0, 0.9, num_classes)
            ]))
        for i in range(num_classes):
            idx = labels == label_set[i]
            plt.plot(umap_embeddings[idx, 0],
                     umap_embeddings[idx, 1],
                     ".",
                     markersize=1)
        plt.show()

    tester = testers.GlobalEmbeddingSpaceTester(
        end_of_testing_hook=hooks.end_of_testing_hook,
        dataloader_num_workers=4,
        accuracy_calculator=AccuracyCalculator(
            include=['mean_average_precision_at_r'], k="max_bin_count"))

    end_of_epoch_hook = hooks.end_of_epoch_hook(tester,
                                                dataset_dict,
                                                model_folder,
                                                patience=15,
                                                splits_to_eval=[('val',
                                                                 ['train'])])

    # Create the trainer
    trainer = trainers.MetricLossOnly(
        models,
        optimizers,
        batch_size,
        loss_funcs,
        mining_funcs,
        train_dataset,
        sampler=sampler,
        dataloader_num_workers=8,
        end_of_iteration_hook=hooks.end_of_iteration_hook,
        end_of_epoch_hook=end_of_epoch_hook)

    trainer.train(num_epochs=num_epochs)
Example #5
    def _init_state_test2(self) -> None:
        """
               Initialize the state and load it from an existing checkpoint if any
               """
        """
         Initialize the state and load it from an existing checkpoint if any
         """
        torch.manual_seed(0)
        np.random.seed(0)
        print("Create data loaders", flush=True)

        Input_size_Image = self._train_cfg.input_size

        Test_size = Input_size_Image
        print("Input size : " + str(Input_size_Image))
        print("Test size : " + str(Input_size_Image))
        print("Initial LR :" + str(self._train_cfg.lr))

        transf = get_transforms(input_size=Input_size_Image,
                                test_size=Test_size,
                                kind='full',
                                crop=True,
                                need=('train', 'val'),
                                backbone=None)
        transform_train = transf['train']
        transform_test = transf['val']

        self.train_set = datasets.ImageFolder(self._train_cfg.imnet_path +
                                              '/train',
                                              transform=transform_train)
        self.test_set = datasets.ImageFolder(self._train_cfg.imnet_path +
                                             '/val',
                                             transform=transform_test)

        self.train_dataset = self.train_set
        self.val_dataset = self.test_set

        # self.train_dataset = ClassDisjointMURA(self.train_set, transform_train)
        # self.val_dataset = ClassDisjointMURA(self.test_set, transform_test)

        print(
            f"Total batch_size: {self._train_cfg.batch_per_gpu * self._train_cfg.num_tasks}",
            flush=True)
        print("Create distributed model", flush=True)

        model = models.resnet152(pretrained=False)
        num_ftrs = model.fc.in_features
        model.fc = common_functions.Identity()

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        model = torch.nn.DataParallel(model.to(device))
        embedder = torch.nn.DataParallel(MLP([num_ftrs, 512]).to(device))

        # Set optimizers
        trunk_optimizer = torch.optim.Adam(model.parameters(),
                                           lr=0.0001,
                                           weight_decay=0.0001)
        embedder_optimizer = torch.optim.Adam(embedder.parameters(),
                                              lr=0.0001,
                                              weight_decay=0.0001)

        # Set the loss function
        loss = losses.TripletMarginLoss(margin=0.1)

        # Set the mining function
        miner = miners.MultiSimilarityMiner(epsilon=0.1)

        # Set the dataloader sampler
        self.sampler = samplers.MPerClassSampler(self.train_dataset.targets,
                                                 m=4,
                                                 length_before_new_iter=len(
                                                     self.train_dataset))

        # Package the above stuff into dictionaries.
        self.models_dict = {"trunk": model, "embedder": embedder}
        self.optimizers = {
            "trunk_optimizer": trunk_optimizer,
            "embedder_optimizer": embedder_optimizer
        }
        self.loss_funcs = {"metric_loss": loss}
        self.mining_funcs = {"tuple_miner": miner}
Example #6
    # train decoder

    encoder = Encoder(
        numerical_input_dim=num_input_dim,
        cat_vocab_sizes=cat_vocab_sizes,
        cat_embedding_dim=cat_embedding_dim,
        embedding_dim=EMBEDDING_DIM,
    )
    encoder.to(device)
    encoder.train()
    optimizer = optim.Adam(encoder.parameters(), lr=LR)

    distance = distances.CosineSimilarity()
    reducer = reducers.ThresholdReducer(low=0)  # basically, returns average
    loss_func = losses.TripletMarginLoss(margin=0.4,
                                         distance=distance,
                                         reducer=reducer)
    mining_func = miners.TripletMarginMiner(margin=0.4,
                                            distance=distance,
                                            type_of_triplets="semihard")

    train_losses = train_ml_model(encoder, NUM_EPOCHS, dataloader,
                                  NUM_OF_SUBSEQUENCES, mining_func, loss_func,
                                  optimizer)
    fig, axs = plt.subplots(figsize=(12, 6))

    plt.plot(train_losses, label='train')
    plt.xlabel('iter')
    plt.ylabel('loss')
    plt.title("final accuracy: {training}")
    plt.savefig(f'plots/ML_{arch}_{EMBEDDING_DIM}_{NUM_OBS}_{NUM_EPOCHS}.png')
def starting_train_updated(train_dataset, val_dataset, model, hyperparameters,
                           n_eval, summary_path, device, bbox_path,
                           train_path):
    """
    Trains and evaluates a model.

    Args:
        train_dataset:   PyTorch dataset containing training data.
        val_dataset:     PyTorch dataset containing validation data.
        model:           PyTorch model to be trained.
        hyperparameters: Dictionary containing hyperparameters.
        n_eval:          Interval at which we evaluate our model.
        summary_path:    Path where Tensorboard summaries are located.
    """
    model.to(device)
    save_path = './model.pt'

    if (path.exists(save_path)):
        model.load_state_dict(torch.load(save_path))

    data = pd.read_csv(train_path)
    data = data.sample(frac=1)  # Shuffle data
    train_eval_data = data.iloc[:int(PERCENT_TRAIN * len(data))]
    test_eval_data = data.iloc[int(PERCENT_TRAIN * len(data)) + 1:]
    # Get keyword arguments
    batch_size, epochs = hyperparameters["batch_size"], hyperparameters[
        "epochs"]

    train_eval_dataset = EvaluationDataset(
        train_eval_data,
        bbox_path,
        "/content/train",
        train=True,
        drop_duplicate_whales=True,
        # If you set this to True, your evaluation accuracy will be lower!!
        # If you set this to False, evaluate() will take longer!!
        # Recommendation: set this to True during training, and when you're done,
        # create a new dataset with drop_duplicate_whales=False to get a final
        # evaluation metric.
    )
    train_eval_dataset.to(device)

    test_eval_dataset = EvaluationDataset(
        test_eval_data,
        bbox_path,
        "/content/train",
        train=False,
        drop_duplicate_whales=False,
    )
    test_eval_dataset.to(device)

    train_eval_loader = torch.utils.data.DataLoader(train_eval_dataset,
                                                    batch_size=BATCH_SIZE)
    test_eval_loader = torch.utils.data.DataLoader(test_eval_dataset,
                                                   batch_size=BATCH_SIZE)

    # Initialize dataloaders
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=True)

    # Initalize optimizer (for gradient descent) and loss function
    optimizer = optim.Adam(model.parameters())
    loss_fn = losses.TripletMarginLoss(margin=hyperparameters["margin"])
    miner = miners.BatchEasyHardMiner(pos_strategy='all', neg_strategy='hard')

    # Initialize summary writer (for logging)
    if summary_path is not None:
        writer = torch.utils.tensorboard.SummaryWriter(summary_path)

    step = 0
    correct = 0
    total = 0
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1} of {epochs}")

        # Loop over each batch in the dataset
        for i, batch in enumerate(train_loader):

            # TODO: unload each "item grouping" from batches
            # Likely can just use: torch.cat(batch)

            print(f"\rIteration {i + 1} of {len(train_loader)} ...", end="")
            images, labels = batch

            images = torch.cat(list(images))
            labels = torch.cat(list(labels))
            images = images.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            embeddings = model(images)  # images is a batch of images
            hard_triplets = miner(embeddings, labels)
            loss = loss_fn(embeddings, labels, hard_triplets)
            loss.backward()
            optimizer.step()

            if step % n_eval == 0:
                print('Evaluating...')
                print(
                    'Current accuracy: ' +
                    str(evaluate(train_eval_loader, test_eval_loader, model)))

            step += 1
            torch.save(model.state_dict(), save_path)

        print()

    print()
Example #8
        for param in model.parameters():
            param.requires_grad = False

    # Make sure out feature dim could match the classifier
    if BACKBONE == 'resnet34':
        in_feature = 512
    else:
        in_feature = 2048

    if MARGIN_TYPE == 'arcMargin':
        # margin = losses.ArcFaceLoss
        margin = ArcMarginProduct(in_feature, num_classes, easy_margin=True)
    elif MARGIN_TYPE == 'inner':
        margin = InnerProduct(in_feature, num_classes)
    elif MARGIN_TYPE == 'tripletMargin':
        margin = losses.TripletMarginLoss(margin=0.1)

    for param in margin.parameters():
        param.requires_grad = True

    # Make model in device 0
    # Set parallel. Single GPU is also okay.
    model.cuda(device_ids[0])
    margin.cuda(device_ids[0])

    # Set optimizer and learning strategy
    if OPTIMIZER == 'adam':
        optimizer = torch.optim.Adam([{
            'params': model.parameters()
        }, {
            'params': margin.parameters()
Example #9
    def __init__(self,
                 num_classes=101,
                 embedding_size=512,
                 trunk_architecture="efficientnet-b0",
                 trunk_optim="RMSprop",
                 embedder_optim="RMSprop",
                 classifier_optim="RMSprop",
                 trunk_lr=1e-4,
                 embedder_lr=1e-3,
                 classifier_lr=1e-3,
                 weight_decay=1.5e-6,
                 trunk_decay=0.98,
                 embedder_decay=0.93,
                 classifier_decay=0.93,
                 log_train=True,
                 gpu_id=0):
        """
        Inputs:
            num_classes int: Number of Classes (for Classifier purely)
            embedding_size int: The size of embedding space output from Embedder
            trunk_architecture str: To pass to self.get_trunk() either efficientnet-b{i} or resnet-18/50 or mobilenet
            trunk_optim optim: Which optimizer to use, such as adamW
            embedder_optim optim: Which optimizer to use, such as adamW
            classifier_optim optim: Which optimizer to use, such as adamW
            trunk_lr float: The learning rate for the Trunk Optimizer
            embedder_lr float: The learning rate for the Embedder Optimizer
            classifier_lr float: The learning rate for the Classifier Optimizer
            weight_decay float: The weight decay for all 3 optimizers
            trunk_decay float: The multiplier for the Scheduler y_{t+1} <- trunk_decay * y_{t}
            embedder_decay float: The multiplier for the Scheduler y_{t+1} <- embedder_decay * y_{t}
            classifier_decay float: The multiplier for the Scheduler y_{t+1} <- classifier_decay * y_{t}
            log_train Bool: whether or not to save training logs
            gpu_id Int: Only currently used to track the GPU useage
        """

        self.gpu_id = gpu_id
        #self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = torch.device(f"cuda")
        self.pretrained = False  # this is used to load the indices for train/val data for now
        self.log_train = log_train

        # build three stage network
        self.num_classes = num_classes
        self.embedding_size = embedding_size
        self.MLP_neurons = 2048  # output size of neural network + size used inside embedder/classifier MLP

        self.get_trunk(trunk_architecture)
        self.trunk = nn.DataParallel(self.trunk.to(self.device))
        self.embedder = nn.DataParallel(
            Network(layer_sizes=[self.MLP_neurons, self.embedding_size],
                    neuron_fc=self.MLP_neurons).to(self.device))
        self.classifier = nn.DataParallel(
            Network(layer_sizes=[self.embedding_size, self.num_classes],
                    neuron_fc=self.MLP_neurons).to(self.device))

        # build optimizers
        self.trunk_optimizer = self.get_optimizer(trunk_optim,
                                                  self.trunk.parameters(),
                                                  lr=trunk_lr,
                                                  weight_decay=weight_decay)
        self.embedder_optimizer = self.get_optimizer(
            embedder_optim,
            self.embedder.parameters(),
            lr=embedder_lr,
            weight_decay=weight_decay)
        self.classifier_optimizer = self.get_optimizer(
            classifier_optim,
            self.classifier.parameters(),
            lr=classifier_lr,
            weight_decay=weight_decay)

        # build schedulers
        self.trunk_scheduler = ExponentialLR(self.trunk_optimizer,
                                             gamma=trunk_decay)
        self.embedder_scheduler = ExponentialLR(self.embedder_optimizer,
                                                gamma=embedder_decay)
        self.classifier_scheduler = ExponentialLR(self.classifier_optimizer,
                                                  gamma=classifier_decay)

        # build pair based losses and the miner
        self.triplet = losses.TripletMarginLoss(margin=0.2).to(self.device)
        self.multisimilarity = losses.MultiSimilarityLoss(alpha=2,
                                                          beta=50,
                                                          base=1).to(
                                                              self.device)
        self.miner = miners.MultiSimilarityMiner(epsilon=0.1)
        # build proxy anchor loss
        self.proxy_anchor = Proxy_Anchor(nb_classes=num_classes,
                                         sz_embed=embedding_size,
                                         mrg=0.2,
                                         alpha=32).to(self.device)
        self.proxy_optimizer = AdamW(self.proxy_anchor.parameters(),
                                     lr=trunk_lr * 10,
                                     weight_decay=1.5E-6)
        self.proxy_scheduler = ExponentialLR(self.proxy_optimizer, gamma=0.8)
        # finally crossentropy loss
        self.crossentropy = torch.nn.CrossEntropyLoss().to(self.device)

        # log some of this information
        self.model_params = {
            "Trunk_Model":
            trunk_architecture,
            "Optimizers": [
                str(self.trunk_optimizer),
                str(self.embedder_optimizer),
                str(self.classifier_optimizer)
            ],
            "Embedder":
            str(self.embedder),
            "Embedding_Dimension":
            str(embedding_size),
            "Weight_Decay":
            weight_decay,
            "Scheduler_Decays":
            [trunk_decay, embedder_decay, classifier_decay],
            "Embedding_Size":
            embedding_size,
            "Learning_Rates": [trunk_lr, embedder_lr, classifier_lr],
            "Miner":
            str(self.miner)
        }
Example #10
def get_loss():
    return losses.TripletMarginLoss(margin=0.1)
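
# Hedged usage sketch (not from the original snippet): pytorch-metric-learning losses
# take a batch of embeddings and integer labels, optionally restricted to the tuples
# returned by a miner. Shapes and the miner choice below are illustrative assumptions.
import torch
from pytorch_metric_learning import miners


def triplet_loss_usage_sketch():
    loss_fn = get_loss()                      # TripletMarginLoss(margin=0.1) from above
    miner = miners.MultiSimilarityMiner()     # optional: mine informative pairs first

    embeddings = torch.randn(32, 128)         # one 128-d embedding per sample
    labels = torch.randint(0, 10, (32,))      # one integer class label per embedding

    loss_all = loss_fn(embeddings, labels)                               # all triplets
    loss_mined = loss_fn(embeddings, labels, miner(embeddings, labels))  # mined tuples only
    return loss_all, loss_mined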
Example #11
    #############################
    #
    # Model
    #
    #############################

    net = load_model(args.checkpoint, 'cpu')

    #############################
    #
    # Loss and optimizer
    #
    #############################
    # miner = miners.MultiSimilarityMiner(epsilon=0.1).to(device)
    miner = miners.BatchHardMiner().to(device)
    loss_func = losses.TripletMarginLoss(margin=0.3).to(device)

    #############################
    #
    # Resume
    #
    #############################

    if dataset_name == 'DeepFashion':
        checkpoint_df_model_name = '{}/{}/{}'.format(checkpoint_path,
                                                     dataset_name, model_name)
        checkpoint_file_name = '{}/MLP_G_Weighting/{}/{}_{}_Finch_{}_{}'.format(
            checkpoint_path, dataset_name, model_name, combinations_type,
            finch_partition_number_, weighting_method_)

    if args.resume_df:
Example #12
# Optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Batch size
batch_size = 32

# Data set
train_path = '/lab/vislab/DATA/CUB/images/train/'
# train_path = '/lab/vislab/DATA/just/infilling/samples/places2/mini/'

# Inpainting mask path
mask_path = './samples/places2/mask/'

# Loss function
criterion = losses.TripletMarginLoss(margin=0.05, triplets_per_anchor="all")
# criterion = torch.nn.CosineEmbeddingLoss()

# Mask to use in random masking
mask = Image.open('./samples/places2/mask/mask_01.png')

transformations = transforms.Compose([
    transforms.Resize((256, 256)),
    # RandomMask(mask),
    transforms.ToTensor(),
])

dataset = datasets.ImageFolder(train_path, transformations)
# test_set = ...

# train_sampler = torch.utils.data.RandomSampler(train_set)
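
# Hedged sketch (not part of the original snippet) of the training loop these pieces
# are typically wired into; `model` stands for whatever embedding network the optimizer
# above was built from, and the loader settings are illustrative.
train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
for images, labels in train_loader:
    optimizer.zero_grad()
    embeddings = model(images)            # (batch, embedding_dim)
    loss = criterion(embeddings, labels)  # triplet loss over all triplets per anchor
    loss.backward()
    optimizer.step()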
Example #13
def train(args):

    ### Prepare Dataset
    if args.data_format=="coco":  #TODO
        from datasets.cocodataset import COCODetection
        trainset = COCODetection(image_path=args.train_images,
                            info_file=args.train_info,
                            transform=SSDAugmentation(MEANS))
        valset = None
    else:
        from datasets.customdataset import CustomDataset
        trainset = CustomDataset(
            image_path=args.train_imgdir,
            info_file=args.train_annofile,
            num_class = args.num_class,
            transform=transforms["train"])
        valset = CustomDataset(
            image_path=args.val_imgdir,
            info_file=args.val_annofile,
            num_class=args.num_class,
            transform=transforms["val"])

    trainloader = DataLoader(trainset, args.train_batch_size,
                             shuffle=True, num_workers=args.num_workers)
    valloader = DataLoader(valset, args.val_batch_size,
                             shuffle=False, num_workers=args.num_workers)
    # batch_num = len(trainloader)


    ### Init Model
    model = getattr(models, args.backbone)(pretrained = args.backbone_pretrained and not args.pretrained_model)
    model = EmbClsNet(model, args.num_class)

    ### Pretrained Model
    if args.pretrained_model:
        model.load_state_dict(torch.load(args.pretrained_model))


    ### Data parallel
    IS_DP_AVAILABLE = False
    try:
        devices = list(map(int,args.devices.strip().split(",")))
        if len(devices)>=2:
            IS_DP_AVAILABLE = True
    except:
        logging.warning(f"Format of args.devices is invalid. {args.devices}")

    if IS_DP_AVAILABLE:
        model = torch.nn.DataParallel(model)
    if args.cuda:
        model = model.cuda()

    ### Init Optimizer
    # optimizer = optim.SGD(model.parameters(), lr=args.start_lr,
    #                       momentum=args.momentum, weight_decay=args.weight_decay)
    optimizer = optim.RMSprop(model.parameters(), lr=args.start_lr, alpha=0.9, eps=1e-08,
                         momentum=args.momentum, weight_decay=args.weight_decay)


    ### Init Triplet Loss
    loss_func = tripletloss.TripletMarginLoss(triplets_per_anchor=args.triplets_per_anchor, margin=args.triplet_margin)
    if args.mining:
        miner = tripletminer.MultiSimilarityMiner(epsilon=args.mining_epsilon)


    interval = -1
    for epoch in range(args.max_epoch):
        for batch_idx, batch_data in enumerate(trainloader):
            interval+=1

            imgs, labels = batch_data

            optimizer.zero_grad()
            embeddings, scores = model(imgs)

            # Metric Learning
            if args.mining:
                hard_pairs = miner(embeddings, labels)
                loss = loss_func(embeddings, labels, hard_pairs)
            else:
                loss = loss_func(embeddings, labels)
            loss.backward()
            optimizer.step()

            # Print Loss
            if interval % args.print_interval == 0:
                logging.info(f"[{epoch:4d}/{args.max_epoch:4d}] {interval:7d} Triplet Loss: {loss}")

            # Validation
            if interval % args.val_interval == 0 and interval != 0:
                logging.info(f"[{epoch}/{args.max_epoch}] Starting Validating..")
                for valbatch_idx, valbatch_data in enumerate(valloader):
                    val_imgs, val_labels = valbatch_data

                    val_embeddings, val_scores = model(val_imgs)

                    # Metric Learning
                    val_loss = loss_func(val_embeddings, val_labels)

                    cls_acc = calculate_class_accuracy(val_scores, val_labels)

                    if valbatch_idx % args.val_print_interval == 0:
                        logging.info(f"[{epoch:4d}/{args.max_epoch:4d}] {interval:7d} Triplet Loss: {val_loss} Cls Acc: {cls_acc}")
Example #14
def triplet_margin_loss(trial, margin_min=0.0, margin_max=2.0, **kwargs):
    margin = trial.suggest_uniform("margin", margin_min, margin_max)

    loss = losses.TripletMarginLoss(margin=margin, **sample_regularizer(trial))

    return {"loss": loss}
Example #15
def train(train_data, test_data, save_model, num_epochs, lr, embedding_size,
          batch_size):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set trunk model and replace the softmax layer with an identity function
    trunk = torchvision.models.resnet18(pretrained=True)
    trunk_output_size = trunk.fc.in_features
    trunk.fc = common_functions.Identity()
    trunk = torch.nn.DataParallel(trunk.to(device))

    # Set embedder model. This takes in the output of the trunk and outputs 64 dimensional embeddings
    embedder = torch.nn.DataParallel(
        MLP([trunk_output_size, embedding_size]).to(device))

    # Set optimizers
    trunk_optimizer = torch.optim.Adam(trunk.parameters(),
                                       lr=lr / 10,
                                       weight_decay=0.0001)
    embedder_optimizer = torch.optim.Adam(embedder.parameters(),
                                          lr=lr,
                                          weight_decay=0.0001)

    # Set the loss function
    loss = losses.TripletMarginLoss(margin=0.1)

    # Set the mining function
    miner = miners.MultiSimilarityMiner(epsilon=0.1)

    # Set the dataloader sampler
    sampler = samplers.MPerClassSampler(train_data.targets,
                                        m=4,
                                        length_before_new_iter=len(train_data))

    save_dir = os.path.join(
        save_model, ''.join(str(lr).split('.')) + '_' + str(batch_size) + '_' +
        str(embedding_size))

    os.makedirs(save_dir, exist_ok=True)

    # Package the above stuff into dictionaries.
    models = {"trunk": trunk, "embedder": embedder}
    optimizers = {
        "trunk_optimizer": trunk_optimizer,
        "embedder_optimizer": embedder_optimizer
    }
    loss_funcs = {"metric_loss": loss}
    mining_funcs = {"tuple_miner": miner}

    record_keeper, _, _ = logging_presets.get_record_keeper(
        os.path.join(save_dir, "example_logs"),
        os.path.join(save_dir, "example_tensorboard"))
    hooks = logging_presets.get_hook_container(record_keeper)

    dataset_dict = {"val": test_data, "train": train_data}
    model_folder = "example_saved_models"

    def visualizer_hook(umapper, umap_embeddings, labels, split_name, keyname,
                        *args):
        logging.info("UMAP plot for the {} split and label set {}".format(
            split_name, keyname))
        label_set = np.unique(labels)
        num_classes = len(label_set)
        fig = plt.figure(figsize=(20, 15))
        plt.title(str(split_name) + '_' + str(num_embeddings))
        plt.gca().set_prop_cycle(
            cycler("color", [
                plt.cm.nipy_spectral(i)
                for i in np.linspace(0, 0.9, num_classes)
            ]))
        for i in range(num_classes):
            idx = labels == label_set[i]
            plt.plot(umap_embeddings[idx, 0],
                     umap_embeddings[idx, 1],
                     ".",
                     markersize=1)
        plt.show()

    # Create the tester
    tester = testers.GlobalEmbeddingSpaceTester(
        end_of_testing_hook=hooks.end_of_testing_hook,
        visualizer=umap.UMAP(),
        visualizer_hook=visualizer_hook,
        dataloader_num_workers=32,
        accuracy_calculator=AccuracyCalculator(k="max_bin_count"))

    end_of_epoch_hook = hooks.end_of_epoch_hook(tester,
                                                dataset_dict,
                                                model_folder,
                                                test_interval=1,
                                                patience=1)

    trainer = trainers.MetricLossOnly(
        models,
        optimizers,
        batch_size,
        loss_funcs,
        mining_funcs,
        train_data,
        sampler=sampler,
        dataloader_num_workers=32,
        end_of_iteration_hook=hooks.end_of_iteration_hook,
        end_of_epoch_hook=end_of_epoch_hook)

    trainer.train(num_epochs=num_epochs)

    if save_model is not None:

        torch.save(models["trunk"].state_dict(),
                   os.path.join(save_dir, 'trunk.pth'))
        torch.save(models["embedder"].state_dict(),
                   os.path.join(save_dir, 'embedder.pth'))

        print('Model saved in ', save_dir)
def train_eval(args, train_data, dev_data):
    logger = logging.getLogger("main")
    # Create dataset & dataloader
    trans = [
        transforms.Resize((224, 224)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]
    trans = transforms.Compose(trans)

    train_dataset, train_char_idx = \
        create_dataset(args.root, train_data, trans)

    train_sampler = MetricBatchSampler(train_dataset,
                                       train_char_idx,
                                       n_max_per_char=args.n_max_per_char,
                                       n_batch_size=args.n_batch_size,
                                       n_random=args.n_random)
    train_dataloader = DataLoader(train_dataset,
                                  batch_sampler=train_sampler,
                                  collate_fn=collate_fn)
    # number of batches given to trainer
    n_batch = int(len(train_dataloader))

    eval_train_dataloaders = \
        prepare_evaluation_dataloaders(args, args.eval_split*3, train_data, trans)
    eval_dev_dataloaders = \
        prepare_evaluation_dataloaders(args, args.eval_split, dev_data, trans)

    # Construct model & optimizer
    device = "cpu" if args.gpu < 0 else "cuda:{}".format(args.gpu)

    trunk = models.resnet18(pretrained=True)
    trunk_output_size = trunk.fc.in_features
    trunk.fc = Identity()
    trunk.to(device)

    model = nn.Sequential(nn.Linear(trunk_output_size, args.emb_dim),
                          Normalize())
    model.to(device)

    if args.optimizer == "SGD":
        trunk_optimizer = torch.optim.SGD(trunk.parameters(),
                                          lr=args.lr,
                                          momentum=args.momentum,
                                          weight_decay=args.decay)
        model_optimizer = torch.optim.SGD(model.parameters(),
                                          lr=args.lr,
                                          momentum=args.momentum,
                                          weight_decay=args.decay)
    else:
        raise NotImplementedError

    loss_func = losses.TripletMarginLoss(margin=args.margin,
                                         normalize_embeddings=args.normalize)

    best_dev_eer = 1.0
    i_epoch = 0

    def end_of_epoch_hook(trainer):
        nonlocal i_epoch, best_dev_eer

        logger.info(f"EPOCH\t{i_epoch}")

        if i_epoch % args.eval_freq == 0:
            train_eer, train_eer_std = evaluate(args, trainer.models["trunk"],
                                                trainer.models["embedder"],
                                                eval_train_dataloaders)
            dev_eer, dev_eer_std = evaluate(args, trainer.models["trunk"],
                                            trainer.models["embedder"],
                                            eval_dev_dataloaders)
            logger.info("Eval EER (mean, std):\t{}\t{}".format(
                train_eer, train_eer_std))
            logger.info("Eval EER (mean, std):\t{}\t{}".format(
                dev_eer, dev_eer_std))
            if dev_eer < best_dev_eer:
                logger.info("New best model!")
                best_dev_eer = dev_eer

        i_epoch += 1

    trainer = trainers.MetricLossOnly(
        models={
            "trunk": trunk,
            "embedder": model
        },
        optimizers={
            "trunk_optimizer": trunk_optimizer,
            "embedder_optimizer": model_optimizer
        },
        batch_size=None,
        loss_funcs={"metric_loss": loss_func},
        mining_funcs={},
        iterations_per_epoch=n_batch,
        dataset=train_dataset,
        data_device=None,
        loss_weights=None,
        sampler=train_sampler,
        collate_fn=collate_fn,
        lr_schedulers=None,  #TODO: use warm-up,
        end_of_epoch_hook=end_of_epoch_hook,
        dataloader_num_workers=1)

    trainer.train(num_epochs=args.epoch)

    if args.save_model:
        torch.save(trainer.models, f"model/{args.suffix}.mdl")

    return best_dev_eer
Example #17
def main():
    muti_train_file = cfg.muti_train_file
    muti_test_file = cfg.muti_test_file
    train_dataset = MutiDateset(muti_train_file)
    test_dataset = MutiDateset(muti_test_file)

    train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False)
    model1 = Dense_fenlei(num_classes=2, dim=2048, dropout=0.5).cuda()
    model2 = ViT(cfg=cfg, feature_seq=16, num_classes=1, dim=2048, depth=8, heads=8,
                 mlp_dim=1024, dropout=0.1, emb_dropout=0.1, batch_normalization=False).cuda()
    model3 = ViT(cfg=cfg, feature_seq=16, num_classes=1, dim=2048, depth=8, heads=8,
                 mlp_dim=1024, dropout=0.1, emb_dropout=0.1).cuda()
    optimizer1 = optim.__dict__[cfg.optim1.name](model1.parameters(), **cfg.optim1.setting)
    optimizer2 = optim.__dict__[cfg.optim2.name](model2.parameters(), **cfg.optim2.setting)
    optimizer3 = optim.__dict__[cfg.optim3.name](model3.parameters(), **cfg.optim3.setting)
    # Decay the learning rate at the specified epochs
    scheduler = optim.lr_scheduler.__dict__[cfg.stepper.name](optimizer1, **cfg.stepper.setting)

    criterion3 = nn.CrossEntropyLoss(torch.Tensor(cfg.loss.weight).cuda())
    #criterion1 = nn.BCEWithLogitsLoss()
    criterion1 = FocalLoss(logits=True)
    
    # Add the triplet loss term (metric learning with cosine-similarity distance)
    distance = CosineSimilarity()
    criterion2 = losses.TripletMarginLoss(distance=distance)


    total_loss, total_loss_place, total_loss_tea=list(), list(), list()
    total_epoch=list()
    total_ap, total_ap_place, total_ap_tea=list(),list(),list()
    total_acc=list()
    max_ap=0
    


    for epoch in range(0,cfg.epoch):
        train_mult(cfg, model1,model2,model3, train_loader, optimizer1,optimizer2, optimizer3, scheduler, epoch, criterion1,criterion2,criterion3)
        loss,loss_place,loss_tea,ap,ap_place,ap_tea,acc=test_mult(cfg, model1, model2, model3, test_loader, criterion1,criterion2,criterion3)
        total_loss.append(loss)
        total_ap.append(ap)
        total_loss_place.append(loss_place)
        total_ap_place.append(ap_place)
        total_loss_tea.append(loss_tea)
        total_ap_tea.append(ap_tea)
        total_epoch.append(epoch)
        total_acc.append(acc)
        print('Test Epoch: {} \tloss: {:.6f}\tap: {:.6f}\tacc: {:.6f}'.format(epoch, loss,ap,acc))
        if ap > max_ap:
            max_ap = ap
            best_model = copy.deepcopy(model3)  # keep a frozen copy of the best weights (needs `import copy`)
    save_path=cfg.store+'.pth'
    torch.save(best_model.state_dict(), save_path)
    
    plt.figure(figsize=(20, 20))
    plt.plot(total_epoch,total_loss,'b^',label=u'loss')
    plt.plot(total_epoch,total_loss_place,'y^',label=u'loss_place')
    plt.plot(total_epoch,total_loss_tea,'r^',label=u'loss_tea')
    plt.legend()
    loss_path=cfg.store+"_loss.png"
    plt.savefig(loss_path)
    
    plt.figure(figsize=(20, 20))
    plt.plot(total_epoch,total_ap,'b^',label=u'AP')
    plt.plot(total_epoch,total_ap_place,'y^',label=u'AP_place')
    plt.plot(total_epoch,total_ap_tea,'r^',label=u'AP_tea')
    plt.legend()
    AP_path=cfg.store+"_AP.png"
    plt.savefig(AP_path)
    
    plt.figure()
    plt.plot(total_epoch,total_acc,'b^',label=u'acc')
    plt.legend()
    acc_path=cfg.store+"_acc.png"
    plt.savefig(acc_path)
    def __init__(self, margin=0.1, **kwargs):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.miner = miners.TripletMarginMiner(margin,
                                               type_of_triplets='semihard')
        self.loss_func = losses.TripletMarginLoss(margin=self.margin)
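
    # Hedged sketch of the forward pass such a wrapper typically pairs with; the original
    # snippet only shows __init__, so this method is an assumption.
    def forward(self, embeddings, labels):
        # Mine semihard triplets first, then score only those with the margin loss.
        indices_tuple = self.miner(embeddings, labels)
        return self.loss_func(embeddings, labels, indices_tuple)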
Example #19
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Set the datasets
train_dataset = datasets.CIFAR100(root="CIFAR100_Dataset",
                                  train=True,
                                  transform=train_transform,
                                  download=True)
val_dataset = datasets.CIFAR100(root="CIFAR100_Dataset",
                                train=False,
                                transform=val_transform,
                                download=True)

# Set the loss function
metric_loss = losses.TripletMarginLoss(margin=0.01)
synth_loss = losses.AngularLoss(alpha=35)
g_adv_loss = losses.AngularLoss(alpha=35)

# Set the mining function
miner = miners.MultiSimilarityMiner(epsilon=0.1)

# Set the dataloader sampler
sampler = samplers.MPerClassSampler(train_dataset.targets,
                                    m=4,
                                    length_before_new_iter=3200)

# Set other training parameters
batch_size = 32
num_epochs = 2
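
# Hedged sketch (not part of the original snippet): the MPerClassSampler above is
# typically handed to the DataLoader in place of shuffling, with drop_last so every
# batch keeps the m-per-class structure.
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size,
                                           sampler=sampler,
                                           drop_last=True,
                                           num_workers=2)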
Example #20
def train_model(model, model_test, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    # best_model_wts = model.state_dict()
    # best_acc = 0.0
    warm_up = 0.1  # We start from the 0.1*lrRate
    warm_iteration = round(dataset_sizes['satellite'] / opt.batchsize) * opt.warm_epoch  # first 5 epoch

    if opt.arcface:
        criterion_arcface = losses.ArcFaceLoss(num_classes=opt.nclasses, embedding_size=512)
    if opt.cosface:
        criterion_cosface = losses.CosFaceLoss(num_classes=opt.nclasses, embedding_size=512)
    if opt.circle:
        criterion_circle = CircleLoss(m=0.25, gamma=32)  # gamma = 64 may lead to a better result.
    if opt.triplet:
        miner = miners.MultiSimilarityMiner()
        criterion_triplet = losses.TripletMarginLoss(margin=0.3)
    if opt.lifted:
        criterion_lifted = losses.GeneralizedLiftedStructureLoss(neg_margin=1, pos_margin=0)
    if opt.contrast:
        criterion_contrast = losses.ContrastiveLoss(pos_margin=0, neg_margin=1)
    if opt.sphere:
        criterion_sphere = losses.SphereFaceLoss(num_classes=opt.nclasses, embedding_size=512, margin=4)

    for epoch in range(num_epochs - start_epoch):
        epoch = epoch + start_epoch
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train']:
            if phase == 'train':
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0.0
            running_corrects2 = 0.0
            running_corrects3 = 0.0
            # Iterate over data.
            for data, data2, data3, data4 in zip(dataloaders['satellite'], dataloaders['street'], dataloaders['drone'],
                                                 dataloaders['google']):
                # get the inputs
                inputs, labels = data
                inputs2, labels2 = data2
                inputs3, labels3 = data3
                inputs4, labels4 = data4
                now_batch_size, c, h, w = inputs.shape
                if now_batch_size < opt.batchsize:  # skip the last batch
                    continue
                if use_gpu:
                    inputs = Variable(inputs.cuda().detach())
                    inputs2 = Variable(inputs2.cuda().detach())
                    inputs3 = Variable(inputs3.cuda().detach())
                    labels = Variable(labels.cuda().detach())
                    labels2 = Variable(labels2.cuda().detach())
                    labels3 = Variable(labels3.cuda().detach())
                    if opt.extra_Google:
                        inputs4 = Variable(inputs4.cuda().detach())
                        labels4 = Variable(labels4.cuda().detach())
                else:
                    inputs, labels = Variable(inputs), Variable(labels)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                if phase == 'val':
                    with torch.no_grad():
                        outputs, outputs2 = model(inputs, inputs2)
                else:
                    if opt.views == 2:
                        outputs, outputs2 = model(inputs, inputs2)
                    elif opt.views == 3:
                        if opt.extra_Google:
                            outputs, outputs2, outputs3, outputs4 = model(inputs, inputs2, inputs3, inputs4)
                        else:
                            outputs, outputs2, outputs3 = model(inputs, inputs2, inputs3)

                return_feature = opt.arcface or opt.cosface or opt.circle or opt.triplet or opt.contrast or opt.lifted or opt.sphere

                if opt.views == 2:
                    _, preds = torch.max(outputs.data, 1)
                    _, preds2 = torch.max(outputs2.data, 1)
                    loss = criterion(outputs, labels) + criterion(outputs2, labels2)
                elif opt.views == 3:
                    if return_feature:
                        logits, ff = outputs
                        logits2, ff2 = outputs2
                        logits3, ff3 = outputs3
                        fnorm = torch.norm(ff, p=2, dim=1, keepdim=True)
                        fnorm2 = torch.norm(ff2, p=2, dim=1, keepdim=True)
                        fnorm3 = torch.norm(ff3, p=2, dim=1, keepdim=True)
                        ff = ff.div(fnorm.expand_as(ff))  # 8*512,tensor
                        ff2 = ff2.div(fnorm2.expand_as(ff2))
                        ff3 = ff3.div(fnorm3.expand_as(ff3))
                        loss = criterion(logits, labels) + criterion(logits2, labels2) + criterion(logits3, labels3)
                        _, preds = torch.max(logits.data, 1)
                        _, preds2 = torch.max(logits2.data, 1)
                        _, preds3 = torch.max(logits3.data, 1)
                        # Multiple perspectives are combined to calculate losses, please join ''--loss_merge'' in run.sh
                        if opt.loss_merge:
                            ff_all = torch.cat((ff, ff2, ff3), dim=0)
                            labels_all = torch.cat((labels, labels2, labels3), dim=0)
                        if opt.extra_Google:
                            logits4, ff4 = outputs4
                            fnorm4 = torch.norm(ff4, p=2, dim=1, keepdim=True)
                            ff4 = ff4.div(fnorm4.expand_as(ff4))
                            loss = criterion(logits, labels) + criterion(logits2, labels2) + criterion(logits3, labels3) +criterion(logits4, labels4)
                            if opt.loss_merge:
                                ff_all = torch.cat((ff_all, ff4), dim=0)
                                labels_all = torch.cat((labels_all, labels4), dim=0)
                        if opt.arcface:
                            if opt.loss_merge:
                                loss += criterion_arcface(ff_all, labels_all)
                            else:
                                loss += criterion_arcface(ff, labels) + criterion_arcface(ff2, labels2) + criterion_arcface(ff3, labels3)  # /now_batch_size
                                if opt.extra_Google:
                                    loss += criterion_arcface(ff4, labels4)  # /now_batch_size
                        if opt.cosface:
                            if opt.loss_merge:
                                loss += criterion_cosface(ff_all, labels_all)
                            else:
                                loss += criterion_cosface(ff, labels) + criterion_cosface(ff2, labels2) + criterion_cosface(ff3, labels3)  # /now_batch_size
                                if opt.extra_Google:
                                    loss += criterion_cosface(ff4, labels4)  # /now_batch_size
                        if opt.circle:
                            if opt.loss_merge:
                                loss += criterion_circle(*convert_label_to_similarity(ff_all, labels_all)) / now_batch_size
                            else:
                                loss += criterion_circle(*convert_label_to_similarity(ff, labels)) / now_batch_size + criterion_circle(*convert_label_to_similarity(ff2, labels2)) / now_batch_size + criterion_circle(*convert_label_to_similarity(ff3, labels3)) / now_batch_size
                                if opt.extra_Google:
                                    loss += criterion_circle(*convert_label_to_similarity(ff4, labels4)) / now_batch_size
                        if opt.triplet:
                            if opt.loss_merge:
                                hard_pairs_all = miner(ff_all, labels_all)
                                loss += criterion_triplet(ff_all, labels_all, hard_pairs_all)
                            else:
                                hard_pairs = miner(ff, labels)
                                hard_pairs2 = miner(ff2, labels2)
                                hard_pairs3 = miner(ff3, labels3)
                                loss += criterion_triplet(ff, labels, hard_pairs) + criterion_triplet(ff2, labels2, hard_pairs2) + criterion_triplet(ff3, labels3, hard_pairs3)# /now_batch_size
                                if opt.extra_Google:
                                    hard_pairs4 = miner(ff4, labels4)
                                    loss += criterion_triplet(ff4, labels4, hard_pairs4)
                        if opt.lifted:
                            if opt.loss_merge:
                                loss += criterion_lifted(ff_all, labels_all)
                            else:
                                loss += criterion_lifted(ff, labels) + criterion_lifted(ff2, labels2) + criterion_lifted(ff3, labels3)  # /now_batch_size
                                if opt.extra_Google:
                                    loss += criterion_lifted(ff4, labels4)
                        if opt.contrast:
                            if opt.loss_merge:
                                loss += criterion_contrast(ff_all, labels_all)
                            else:
                                loss += criterion_contrast(ff, labels) + criterion_contrast(ff2,labels2) + criterion_contrast(ff3, labels3)  # /now_batch_size
                                if opt.extra_Google:
                                    loss += criterion_contrast(ff4, labels4)
                        if opt.sphere:
                            if opt.loss_merge:
                                loss += criterion_sphere(ff_all, labels_all) / now_batch_size
                            else:
                                loss += criterion_sphere(ff, labels) / now_batch_size + criterion_sphere(ff2, labels2) / now_batch_size + criterion_sphere(ff3, labels3) / now_batch_size
                                if opt.extra_Google:
                                    loss += criterion_sphere(ff4, labels4)

                    else:
                        _, preds = torch.max(outputs.data, 1)
                        _, preds2 = torch.max(outputs2.data, 1)
                        _, preds3 = torch.max(outputs3.data, 1)
                        if opt.loss_merge:
                            outputs_all = torch.cat((outputs, outputs2, outputs3), dim=0)
                            labels_all = torch.cat((labels, labels2, labels3), dim=0)
                            if opt.extra_Google:
                                outputs_all = torch.cat((outputs_all, outputs4), dim=0)
                                labels_all = torch.cat((labels_all, labels4), dim=0)
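                            # The factor of 4 keeps the merged (mean-reduced) cross-entropy roughly on the scale of summing four per-view losses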
                            loss = 4*criterion(outputs_all, labels_all)
                        else:
                            loss = criterion(outputs, labels) + criterion(outputs2, labels2) + criterion(outputs3, labels3)
                            if opt.extra_Google:
                                loss += criterion(outputs4, labels4)

                # Warm-up: scale the loss down during the first opt.warm_epoch epochs; backward + optimize only in the training phase
                if epoch < opt.warm_epoch and phase == 'train':
                    warm_up = min(1.0, warm_up + 0.9 / warm_iteration)
                    loss *= warm_up

                if phase == 'train':
                    if fp16:  # use the optimizer to scale the loss for backward (apex amp mixed precision)
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()
                    optimizer.step()
                    ##########
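                    # Maintain a moving-average copy of the weights in model_test when weight averaging is enabled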
                    if opt.moving_avg < 1.0:
                        update_average(model_test, model, opt.moving_avg)

                # statistics
                if int(version[0]) > 0 or int(version[2]) > 3:  # for the new version like 0.4.0, 0.5.0 and 1.0.0
                    running_loss += loss.item() * now_batch_size
                else:  # for the old version like 0.3.0 and 0.3.1
                    running_loss += loss.data[0] * now_batch_size
                running_corrects += float(torch.sum(preds == labels.data))
                running_corrects2 += float(torch.sum(preds2 == labels2.data))
                if opt.views == 3:
                    running_corrects3 += float(torch.sum(preds3 == labels3.data))

            epoch_loss = running_loss / dataset_sizes['satellite']
            epoch_acc = running_corrects / dataset_sizes['satellite']
            epoch_acc2 = running_corrects2 / dataset_sizes['satellite']

            if opt.views == 2:
                print('{} Loss: {:.4f} Satellite_Acc: {:.4f}  Street_Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc,
                                                                                         epoch_acc2))
            elif opt.views == 3:
                epoch_acc3 = running_corrects3 / dataset_sizes['satellite']
                print('{} Loss: {:.4f} Satellite_Acc: {:.4f}  Street_Acc: {:.4f} Drone_Acc: {:.4f}'.format(phase,
                                                                                                           epoch_loss,
                                                                                                           epoch_acc,
                                                                                                           epoch_acc2,
                                                                                                           epoch_acc3))

            y_loss[phase].append(epoch_loss)
            y_err[phase].append(1.0 - epoch_acc)
            # deep copy the model
            if phase == 'train':
                scheduler.step()
            last_model_wts = model.state_dict()
            if epoch % 20 == 19:
                save_network(model, opt.name, epoch)
            # draw_curve(epoch)

        time_elapsed = time.time() - since
        print('Epoch complete. Elapsed so far: {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    # print('Best val Acc: {:4f}'.format(best_acc))
    # save_network(model_test, opt.name+'adapt', epoch)

    return model
Example #21
model = models.resnet50(pretrained=True)

# Optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Batch size
batch_size = 32

# Dataset
train_path = '/lab/vislab/DATA/CUB/images/train/'
test_path = '/lab/vislab/DATA/CUB/images/test/'
# train_path = '/lab/vislab/DATA/just/infilling/samples/places2/mini/'
# test_path = '/lab/vislab/DATA/just/infilling/samples/places2/mini/'

# Loss function
criterion = losses.TripletMarginLoss(
    margin=0.2, triplets_per_anchor="all")  # triplets_per_anchor="all" already performs batch-all triplet mining
# criterion = BatchAllLoss(device, margin=0.2)
# criterion = nn.CrossEntropyLoss()

# criterion = torch.nn.CosineEmbeddingLoss()


class RandomMask(object):
    """Add random occlusions to image.

    Args:
        mask: (Image.Image) - Image to use to occlude.
    """
    def __init__(self, mask):
        assert isinstance(mask, Image.Image)
        self.mask = mask
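    # NOTE: the original snippet is truncated here. The __call__ below is a
    # hypothetical sketch of how such an occlusion transform is commonly
    # completed (pasting the mask at a random position); it assumes a
    # module-level `import random` and is not the original author's code.
    def __call__(self, img):
        assert isinstance(img, Image.Image)
        occluded = img.copy()
        max_x = max(occluded.width - self.mask.width, 0)
        max_y = max(occluded.height - self.mask.height, 0)
        occluded.paste(self.mask, (random.randint(0, max_x), random.randint(0, max_y)))
        return occluded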
Example #22
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Set the datasets
train_dataset = datasets.CIFAR100(root="CIFAR100_Dataset",
                                  train=True,
                                  transform=train_transform,
                                  download=True)
val_dataset = datasets.CIFAR100(root="CIFAR100_Dataset",
                                train=False,
                                transform=val_transform,
                                download=True)

# Set the loss functions. loss0 will be applied to the first embedder, loss1 to the second embedder etc.
loss0 = losses.TripletMarginLoss(margin=0.01)
loss1 = losses.MultiSimilarityLoss(alpha=0.1, beta=40, base=0.5)
loss2 = losses.ArcFaceLoss(margin=30, num_classes=100,
                           embedding_size=64).to(device)

# Set the mining functions. In this example we'll apply mining to the 2nd and 3rd cascaded outputs.
miner1 = miners.MultiSimilarityMiner(epsilon=0.1)
miner2 = miners.HDCMiner(filter_percentage=0.25)

# Set the dataloader sampler
sampler = samplers.MPerClassSampler(train_dataset.targets, m=4)
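# MPerClassSampler forms each batch from m=4 samples per sampled class, so every batch contains positives for the metric losses above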

# Set other training parameters
batch_size = 32
num_epochs = 2
iterations_per_epoch = 100
Example #23
    def __init__(self,
                 train_dl,
                 val_dl,
                 unseen_dl,
                 model,
                 optimizer,
                 scheduler,
                 criterion,
                 mining_function,
                 loss,
                 savePath='./models/',
                 device='cuda',
                 BATCH_SIZE=64):
        self.device = device
        self.train_dl = train_dl
        self.val_dl = val_dl
        self.unseen_dl = unseen_dl
        self.BATCH_SIZE = BATCH_SIZE
        self.model = model.to(self.device)
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.criterion = criterion
        self.mining_function = mining_function
        self.loss = loss
        self.distance = distances.LpDistance(normalize_embeddings=True,
                                             p=2,
                                             power=1)
        self.reducer = reducers.ThresholdReducer(low=0)
        self.regularizer = regularizers.LpRegularizer(p=2)
        if self.mining_function == 'triplet':
            self.mining_func = miners.TripletMarginMiner(
                margin=0.01,
                distance=self.distance,
                type_of_triplets="semihard")
        elif self.mining_function == 'pair':
            self.mining_func = miners.PairMarginMiner(pos_margin=0,
                                                      neg_margin=0.2)

        if self.loss == 'triplet':
            self.loss_function = losses.TripletMarginLoss(
                margin=0.01, distance=self.distance, reducer=self.reducer)
        elif self.loss == 'contrastive':
            self.loss_function = losses.ContrastiveLoss(pos_margin=0,
                                                        neg_margin=1.5)
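        # For the classification-style losses below, the first two positional arguments are the number of classes (9) and the embedding size (128)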
        elif self.loss == 'panc':
            self.loss_function = losses.ProxyAnchorLoss(
                9,
                128,
                margin=0.01,
                alpha=5,
                reducer=self.reducer,
                weight_regularizer=self.regularizer)
        elif self.loss == 'pnca':
            self.loss_function = losses.ProxyNCALoss(
                9,
                128,
                softmax_scale=1,
                reducer=self.reducer,
                weight_regularizer=self.regularizer)
        elif self.loss == 'normsoftmax':
            self.loss_function = losses.NormalizedSoftmaxLoss(
                9,
                128,
                temperature=0.05,
                reducer=self.reducer,
                weight_regularizer=self.regularizer)

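        # ProxyAnchor, ProxyNCA and NormalizedSoftmax carry learnable class proxies/weights, so they get their own optimizer and LR scheduler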
        if self.loss in ['normsoftmax', 'panc', 'pnca']:
            self.loss_optimizer = optim.SGD(self.loss_function.parameters(),
                                            lr=0.0001,
                                            momentum=0.9)
            self.loss_scheduler = lr_scheduler.ReduceLROnPlateau(
                self.loss_optimizer,
                'min',
                patience=3,
                threshold=0.0001,
                factor=0.1,
                verbose=True)

        self.savePath = savePath + 'efigi{}_{}_128'.format(
            self.mining_function, self.loss)
Example #24
def train_eval(args, train_data, dev_data):
    logger = logging.getLogger("main")
    # Create dataset & dataloader
    trans = [
        transforms.Resize((224, 224)),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ]
    trans = transforms.Compose(trans)

    train_dataset, train_char_idx = \
        create_dataset(args.root, train_data, trans)

    train_sampler = MetricBatchSampler(train_dataset,
                                       train_char_idx,
                                       n_max_per_char=args.n_max_per_char,
                                       n_batch_size=args.n_batch_size,
                                       n_random=args.n_random)
    train_dataloader = DataLoader(train_dataset,
                                  batch_sampler=train_sampler,
                                  collate_fn=collate_fn)
    # number of batches given to trainer
    n_batch = int(len(train_dataloader))

    eval_train_dataloaders = \
        prepare_evaluation_dataloaders(args, args.eval_split*3, train_data, trans)
    eval_dev_dataloaders = \
        prepare_evaluation_dataloaders(args, args.eval_split, dev_data, trans)

    # Construct model & optimizer
    device = "cpu" if args.gpu < 0 else "cuda:{}".format(args.gpu)

    trunk, model = create_models(args.emb_dim, args.dropout)
    trunk.to(device)
    model.to(device)

    if args.metric_loss == "triplet":
        loss_func = losses.TripletMarginLoss(
            margin=args.margin,
            normalize_embeddings=args.normalize,
            smooth_loss=args.smooth)
    elif args.metric_loss == "arcface":
        loss_func = losses.ArcFaceLoss(margin=args.margin,
                                       num_classes=len(train_data),
                                       embedding_size=args.emb_dim)
        loss_func.to(device)

    if args.optimizer == "SGD":
        trunk_optimizer = torch.optim.SGD(trunk.parameters(),
                                          lr=args.lr,
                                          momentum=args.momentum,
                                          weight_decay=args.decay)
        model_optimizer = torch.optim.SGD(model.parameters(),
                                          lr=args.lr,
                                          momentum=args.momentum,
                                          weight_decay=args.decay)
        optimizers = {
            "trunk_optimizer": trunk_optimizer,
            "embedder_optimizer": model_optimizer
        }
        if args.metric_loss == "arcface":
            loss_optimizer = torch.optim.SGD(loss_func.parameters(),
                                             lr=args.lr,
                                             momentum=args.momentum,
                                             weight_decay=args.decay)
            optimizers["loss_optimizer"] = loss_optimizer
    elif args.optimizer == "Adam":
        trunk_optimizer = torch.optim.Adam(trunk.parameters(),
                                           lr=args.lr,
                                           weight_decay=args.decay)
        model_optimizer = torch.optim.Adam(model.parameters(),
                                           lr=args.lr,
                                           weight_decay=args.decay)
        optimizers = {
            "trunk_optimizer": trunk_optimizer,
            "embedder_optimizer": model_optimizer
        }
        if args.metric_loss == "arcface":
            loss_optimizer = torch.optim.Adam(loss_func.parameters(),
                                              lr=args.lr,
                                              weight_decay=args.decay)
            optimizers["loss_optimizer"] = loss_optimizer
    else:
        raise NotImplementedError

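    # LambdaLR multiplier: linear warm-up over the first args.warmup steps, then the LR is divided by args.decay_factor every args.decay_freq steps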
    def lr_func(step):
        if step < args.warmup:
            return (step + 1) / args.warmup
        else:
            steps_decay = step // args.decay_freq
            return 1 / args.decay_factor**steps_decay

    trunk_scheduler = torch.optim.lr_scheduler.LambdaLR(
        trunk_optimizer, lr_func)
    model_scheduler = torch.optim.lr_scheduler.LambdaLR(
        model_optimizer, lr_func)
    schedulers = {
        "trunk_scheduler": trunk_scheduler,
        "model_scheduler": model_scheduler
    }

    if args.miner == "none":
        mining_funcs = {}
    elif args.miner == "batch-hard":
        mining_funcs = {
            "post_gradient_miner": miners.BatchHardMiner(use_similarity=True)
        }

    best_dev_eer = 1.0
    i_epoch = 0

    def end_of_epoch_hook(trainer):
        nonlocal i_epoch, best_dev_eer

        logger.info(f"EPOCH\t{i_epoch}")

        if i_epoch % args.eval_freq == 0:
            train_eer, train_eer_std = evaluate(args, trainer.models["trunk"],
                                                trainer.models["embedder"],
                                                eval_train_dataloaders)
            dev_eer, dev_eer_std = evaluate(args, trainer.models["trunk"],
                                            trainer.models["embedder"],
                                            eval_dev_dataloaders)
            logger.info("Eval EER (mean, std):\t{}\t{}".format(
                train_eer, train_eer_std))
            logger.info("Eval EER (mean, std):\t{}\t{}".format(
                dev_eer, dev_eer_std))
            if dev_eer < best_dev_eer:
                logger.info("New best model!")
                best_dev_eer = dev_eer

        i_epoch += 1

    def end_of_iteration_hook(trainer):
        for scheduler in schedulers.values():
            scheduler.step()

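    # MetricLossOnly trains trunk + embedder with the metric loss alone; lr_schedulers is None, so the end-of-iteration hook above steps the LambdaLR schedulers manually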
    trainer = trainers.MetricLossOnly(
        models={
            "trunk": trunk,
            "embedder": model
        },
        optimizers=optimizers,
        batch_size=None,
        loss_funcs={"metric_loss": loss_func},
        mining_funcs=mining_funcs,
        iterations_per_epoch=n_batch,
        dataset=train_dataset,
        data_device=None,
        loss_weights=None,
        sampler=train_sampler,
        collate_fn=collate_fn,
        lr_schedulers=None,
        end_of_epoch_hook=end_of_epoch_hook,
        end_of_iteration_hook=end_of_iteration_hook,
        dataloader_num_workers=1)

    trainer.train(num_epochs=args.epoch)

    if args.save_model:
        save_models = {
            "trunk": trainer.models["trunk"].state_dict(),
            "embedder": trainer.models["embedder"].state_dict(),
            "args": [args.emb_dim]
        }
        torch.save(save_models, f"model/{args.suffix}.mdl")

    return best_dev_eer
Example #25
def train_app(cfg):
    print(cfg.pretty())

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Set trunk model and replace the softmax layer with an identity function
    trunk = torchvision.models.__dict__[cfg.model.model_name](pretrained=cfg.model.pretrained)
    
    #resnet18(pretrained=True)
    #trunk = models.alexnet(pretrained=True)
    #trunk = models.resnet50(pretrained=True)
    #trunk = models.resnet152(pretrained=True)
    #trunk = models.wide_resnet50_2(pretrained=True)
    #trunk = EfficientNet.from_pretrained('efficientnet-b2')
    trunk_output_size = trunk.fc.in_features
    trunk.fc = Identity()
    trunk = torch.nn.DataParallel(trunk.to(device))

    embedder = torch.nn.DataParallel(MLP([trunk_output_size, cfg.embedder.size]).to(device))
    classifier = torch.nn.DataParallel(MLP([cfg.embedder.size, cfg.embedder.class_out_size])).to(device)

    # Set optimizers
    if cfg.optimizer.name == "sdg":
        trunk_optimizer = torch.optim.SGD(trunk.parameters(), lr=cfg.optimizer.lr, momentum=cfg.optimizer.momentum, weight_decay=cfg.optimizer.weight_decay)
        embedder_optimizer = torch.optim.SGD(embedder.parameters(), lr=cfg.optimizer.lr, momentum=cfg.optimizer.momentum, weight_decay=cfg.optimizer.weight_decay)
        classifier_optimizer = torch.optim.SGD(classifier.parameters(), lr=cfg.optimizer.lr, momentum=cfg.optimizer.momentum, weight_decay=cfg.optimizer.weight_decay)
    elif cfg.optimizer.name == "rmsprop":
        trunk_optimizer = torch.optim.RMSprop(trunk.parameters(), lr=cfg.optimizer.lr, momentum=cfg.optimizer.momentum, weight_decay=cfg.optimizer.weight_decay)
        embedder_optimizer = torch.optim.RMSprop(embedder.parameters(), lr=cfg.optimizer.lr, momentum=cfg.optimizer.momentum, weight_decay=cfg.optimizer.weight_decay)
        classifier_optimizer = torch.optim.RMSprop(classifier.parameters(), lr=cfg.optimizer.lr, momentum=cfg.optimizer.momentum, weight_decay=cfg.optimizer.weight_decay)



    # Set the datasets
    data_dir = os.environ["DATASET_FOLDER"]+"/"+cfg.dataset.data_dir
    print("Data dir: "+data_dir)

    train_dataset, val_dataset, val_samples_dataset = get_datasets(data_dir, cfg, mode=cfg.mode.type)
    print("Trainset: ",len(train_dataset), "Testset: ",len(val_dataset), "Samplesset: ",len(val_samples_dataset))

    # Set the loss function
    if cfg.embedder_loss.name == "margin_loss":
        loss = losses.MarginLoss(margin=cfg.embedder_loss.margin,nu=cfg.embedder_loss.nu,beta=cfg.embedder_loss.beta)
    if cfg.embedder_loss.name == "triplet_margin":
        loss = losses.TripletMarginLoss(margin=cfg.embedder_loss.margin)
    if cfg.embedder_loss.name == "multi_similarity":
        loss = losses.MultiSimilarityLoss(alpha=cfg.embedder_loss.alpha, beta=cfg.embedder_loss.beta, base=cfg.embedder_loss.base)

    # Set the classification loss:
    classification_loss = torch.nn.CrossEntropyLoss()

    # Set the mining function

    if cfg.miner.name == "triplet_margin":
        #miner = miners.TripletMarginMiner(margin=0.2)
        miner = miners.TripletMarginMiner(margin=cfg.miner.margin)
    if cfg.miner.name == "multi_similarity":
        miner = miners.MultiSimilarityMiner(epsilon=cfg.miner.epsilon)
        #miner = miners.MultiSimilarityMiner(epsilon=0.05)

    batch_size = cfg.trainer.batch_size
    num_epochs = cfg.trainer.num_epochs
    iterations_per_epoch = cfg.trainer.iterations_per_epoch
    # Set the dataloader sampler
    sampler = samplers.MPerClassSampler(train_dataset.targets, m=4, length_before_new_iter=len(train_dataset))


    # Package the above stuff into dictionaries.
    models = {"trunk": trunk, "embedder": embedder, "classifier": classifier}
    optimizers = {"trunk_optimizer": trunk_optimizer, "embedder_optimizer": embedder_optimizer, "classifier_optimizer": classifier_optimizer}
    loss_funcs = {"metric_loss": loss, "classifier_loss": classification_loss}
    mining_funcs = {"tuple_miner": miner}

    # We can specify loss weights if we want to. This is optional
    loss_weights = {"metric_loss": cfg.loss.metric_loss, "classifier_loss": cfg.loss.classifier_loss}


    schedulers = {
            #"metric_loss_scheduler_by_epoch": torch.optim.lr_scheduler.StepLR(classifier_optimizer, cfg.scheduler.step_size, gamma=cfg.scheduler.gamma),
            "embedder_scheduler_by_epoch": torch.optim.lr_scheduler.StepLR(embedder_optimizer, cfg.scheduler.step_size, gamma=cfg.scheduler.gamma),
            "classifier_scheduler_by_epoch": torch.optim.lr_scheduler.StepLR(classifier_optimizer, cfg.scheduler.step_size, gamma=cfg.scheduler.gamma),
            "trunk_scheduler_by_epoch": torch.optim.lr_scheduler.StepLR(embedder_optimizer, cfg.scheduler.step_size, gamma=cfg.scheduler.gamma),
            }

    experiment_name = "%s_model_%s_cl_%s_ml_%s_miner_%s_mix_ml_%02.2f_mix_cl_%02.2f_resize_%d_emb_size_%d_class_size_%d_opt_%s_lr_%02.2f_m_%02.2f_wd_%02.2f"%(cfg.dataset.name,
                                                                                                  cfg.model.model_name, 
                                                                                                  "cross_entropy", 
                                                                                                  cfg.embedder_loss.name, 
                                                                                                  cfg.miner.name, 
                                                                                                  cfg.loss.metric_loss, 
                                                                                                  cfg.loss.classifier_loss,
                                                                                                  cfg.transform.transform_resize,
                                                                                                  cfg.embedder.size,
                                                                                                  cfg.embedder.class_out_size,
                                                                                                  cfg.optimizer.name,
                                                                                                  cfg.optimizer.lr,
                                                                                                  cfg.optimizer.momentum,
                                                                                                  cfg.optimizer.weight_decay)
    record_keeper, _, _ = logging_presets.get_record_keeper("logs/%s"%(experiment_name), "tensorboard/%s"%(experiment_name))
    hooks = logging_presets.get_hook_container(record_keeper)
    dataset_dict = {"samples": val_samples_dataset, "val": val_dataset}
    model_folder = "example_saved_models/%s/"%(experiment_name)

    # Create the tester
    tester = OneShotTester(
            end_of_testing_hook=hooks.end_of_testing_hook, 
            #size_of_tsne=20
            )
    #tester.embedding_filename=data_dir+"/embeddings_pretrained_triplet_loss_multi_similarity_miner.pkl"
    tester.embedding_filename=data_dir+"/"+experiment_name+".pkl"
    end_of_epoch_hook = hooks.end_of_epoch_hook(tester, dataset_dict, model_folder)
    trainer = trainers.TrainWithClassifier(models,
            optimizers,
            batch_size,
            loss_funcs,
            mining_funcs,
            train_dataset,
            sampler=sampler,
            lr_schedulers=schedulers,
            dataloader_num_workers = cfg.trainer.batch_size,
            loss_weights=loss_weights,
            end_of_iteration_hook=hooks.end_of_iteration_hook,
            end_of_epoch_hook=end_of_epoch_hook
            )

    trainer.train(num_epochs=num_epochs)

    tester = OneShotTester()