def __init__(self, *, memory: CILMemory, margin: float, coef: float):
    super().__init__()
    self.coef = coef
    self.memory = memory
    self.criterion = TripletMarginLoss(margin=margin)

    self._penalty: TripletLossPenalty = None
Example #2
def get_default_loss_function():
    """
    Creates a loss function object with predefined default settings.

    :return: A TripletMarginLoss object with default parameters
    (margin=20.0, p=2, reduction='sum').
    """
    return TripletMarginLoss(margin=20.0, p=2, reduction='sum')
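A minimal usage sketch with random embeddings (the batch size and dimensionality are illustrative):

import torch

loss_func = get_default_loss_function()
anchor, positive, negative = (torch.randn(8, 128) for _ in range(3))
loss = loss_func(anchor, positive, negative)  # scalar, summed over the batch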
Example #3
def triplet_loss(anchor, positive, negative):
    # return torch.sum(torch.sum((anchor - positive).pow(2), 1) - torch.sum((anchor - negative).pow(2), 1))
    # anchor = normalize(anchor)
    # positive = normalize(positive)
    # negative = normalize(negative)
    anchor = anchor.view(anchor.shape[0], -1)
    positive = positive.view(positive.shape[0], -1)
    negative = negative.view(negative.shape[0], -1)
    return TripletMarginLoss(reduction="sum")(anchor, positive,
                                              negative) / anchor.shape[0]
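A minimal usage sketch; the trailing dimensions are arbitrary since the function flattens each input to (batch, -1):

import torch

a, p, n = (torch.randn(8, 16, 4) for _ in range(3))
print(triplet_loss(a, p, n).item())  # summed loss normalised by batch size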
Example #4
def __init__(self, margin: float,
             sampler_inbatch: "IInbatchTripletSampler"):
    """
    Args:
        margin: margin value
        sampler_inbatch: sampler for forming triplets inside the batch
    """
    super().__init__()
    self._sampler_inbatch = sampler_inbatch
    self._triplet_margin_loss = TripletMarginLoss(margin=margin)
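Only the constructor is shown here; a minimal sketch of how a forward pass might combine the two pieces, assuming the sampler exposes a `sample(features, labels)` method returning (anchor, positive, negative) feature triples (an assumption, not confirmed by the snippet):

def forward(self, features: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
    # hypothetical: the sampler mines triplets inside the batch,
    # then the standard margin loss is applied to them
    anchor, positive, negative = self._sampler_inbatch.sample(
        features=features, labels=labels)
    return self._triplet_margin_loss(anchor, positive, negative)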
Example #5
def main():
    """Train or evaluate a triplet siamese network on MNIST."""
    data_dir = Path.home() / "Data" / "mnist_png"
    train_batch_size = 64
    train_number_epochs = 100
    save_path = PROJECT_APP_PATH.user_data / "models"
    model_name = "triplet_siamese_mnist"
    load_prev = True
    train = False

    img_size = (28, 28)
    model = NLetConvNet(img_size).to(global_torch_device())
    optimiser = optim.Adam(model.parameters(), lr=3e-4)

    if train:
        if load_prev:
            model, optimiser = load_model_parameters(
                model,
                optimiser=optimiser,
                model_name=model_name,
                model_directory=save_path,
            )

        with TensorBoardPytorchWriter():
            # from draugr.stopping import CaptureEarlyStop

            # with CaptureEarlyStop() as _:
            with IgnoreInterruptSignal():
                model = train_siamese(
                    model,
                    optimiser,
                    TripletMarginLoss().to(global_torch_device()),
                    train_number_epochs=train_number_epochs,
                    data_dir=data_dir,
                    train_batch_size=train_batch_size,
                    model_name=model_name,
                    save_path=save_path,
                    img_size=img_size,
                )
        save_model_parameters(
            model,
            optimiser=optimiser,
            model_name=model_name,
            save_directory=save_path,
        )
    else:
        model = load_model_parameters(model,
                                      model_name=model_name,
                                      model_directory=save_path)
        print("loaded best val")
        stest_many_versus_many(model, data_dir, img_size)
Example #6
def get_loss_function(loss_func: typing.Union[str, 'CustomLossFunctions'],
                      **loss_func_args):
    """
    Create a custom loss function with custom arguments.

    :param loss_func: Name of the loss function, e.g. 'TripletMarginLoss', ...
    :param loss_func_args: Custom parameters like margin, p (norm degree), reduction, etc.
    :return: Selected loss function with custom parameters.
    """
    if (loss_func == CustomLossFunctions.TripletMarginLoss) \
            or (loss_func in CustomLossFunctions.TripletMarginLoss.value):
        return TripletMarginLoss(margin=loss_func_args['margin'],
                                 p=loss_func_args['norm_degree'],
                                 reduction=loss_func_args['reduction'])
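A minimal usage sketch (the argument values are illustrative):

loss_fn = get_loss_function(CustomLossFunctions.TripletMarginLoss,
                            margin=1.0,
                            norm_degree=2,
                            reduction='mean')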
Example #7
def train(train_loader, neg_gen):
    model = UAEModel(device).to(device)
    if args.warm_up:
        model = load_model(args.max_len, device)
    optimizer = optim.Adam(model.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    loss_function = TripletMarginLoss()
    min_loss = np.inf
    for epoch in range(args.epoch):
        print('%d / %d Epoch' % (epoch, args.epoch))
        epoch_loss = train_epoch(train_loader, neg_gen, model, optimizer,
                                 loss_function)
        print(epoch_loss)
        if epoch_loss < min_loss:
            min_loss = epoch_loss
            torch.save(model.state_dict(), 'model.bin')
    return model
Example #8
def train(epochs, model, data_loaders, model_path):

    model.to(device)
    #summary(model, (3, 299, 299))

    criterion = TripletMarginLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    best_model_wts = None
    best_acc = 0.0
    
    # Train the model
    for epoch in range(1, epochs + 1):
        print(f'Epoch {epoch} of {epochs}')

        model.train()

        for inputs, targets in data_loaders['train']:
            # Set mini-batch dataset
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = dim_reduce(model(inputs))

            # Make a derangement of targets so the negative is always different from the positive
            negatives = make_derangement(targets)
            loss = criterion(outputs, targets, negatives)
            loss.backward()
            optimizer.step()

        scheduler.step()

        image_to_text, text_to_image = get_test_results(model, data_loaders['val'])
        if image_to_text > best_acc:
            best_acc = image_to_text
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(best_model_wts, model_path)
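`make_derangement` is not defined in the snippet; a minimal sketch of one possible implementation, assuming `targets` are embeddings batched along dim 0 (a cyclic shift is the simplest derangement; the name and strategy are illustrative):

import torch

def make_derangement(targets: torch.Tensor) -> torch.Tensor:
    # hypothetical sketch: a cyclic shift by one position guarantees
    # no sample is paired with itself as its own negative
    return torch.roll(targets, shifts=1, dims=0)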
Example #9
def retrieval_loss2(output, target, no_negatives: int):
    """Function returns triplet loss obtained from comparing the anchor
    with the corresponding postivie label and with another no_negatives
    negative examples
    Inputs:
        output: tensor containing image embedding
        target: tensor containing text embedding
        no_negatives: integer containing number of negative examples"""
    triplet_loss = TripletMarginLoss(margin=1.0, p=2)
    batch_len = len(target)
    batch_loss = 0
    negative_idx = np.random.randint(0, batch_len, np.min([batch_len, no_negatives]))

    total_examples = len(negative_idx) * batch_len
    batch_example_dim = len(output[0])
    # import pdb; pdb.set_trace()
    for k, output_example in enumerate(output):
        output_example_reshaped = output_example.reshape([1, batch_example_dim])
        for idx in negative_idx:
            if idx != k:
                batch_loss += triplet_loss(output_example_reshaped, target[k], target[idx])
            else:
                total_examples -= 1
    return batch_loss / total_examples
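A minimal usage sketch with random embeddings (batch size and dimensionality are illustrative; numpy and TripletMarginLoss are assumed imported as in the snippet):

import torch

image_emb = torch.randn(8, 128)
text_emb = torch.randn(8, 128)
print(retrieval_loss2(image_emb, text_emb, no_negatives=4).item())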
Example #10
        return triplet_embeddings


if __name__ == '__main__':
    triplet_5033_dataset = TripletsDataset()
    dataloader = DataLoader(triplet_5033_dataset,
                            batch_size=32,
                            shuffle=True,
                            num_workers=10)
    model = DeepRanking(triplet_5033_dataset, 4096)
    model.float()
    optimizer = SGD(model.parameters(), lr=0.001, momentum=0.9)
    loss_tr = []
    big_l = []
    total_step = len(dataloader)
    triplet_loss = TripletMarginLoss(margin=1.0, p=2)  # construct once, reuse every step
    epochs = range(6)
    for epoch in epochs:
        for i, batch in enumerate(dataloader):
            print(i, end='\r')
            for triplet in zip(*batch):
                optimizer.zero_grad()
                output = model(triplet)
                loss = triplet_loss(*output)
                loss.backward()
                loss_tr.append(loss.item())
                optimizer.step()
                print(
                    f"Epoch [{epoch + 1}/{len(epochs)}], Step [{i + 1}/{total_step}]"
                    f" Loss: {loss.item()}")
Example #11
def train(model,
          train_set,
          test_set,
          epochs_,
          learning_rate,
          model_name: str,
          cuda=True,
          start_epoch_=0,
          adam: bool = True,
          patience: int = 5) -> (list, list):
    """
    Train a siamese network with Adam or RMSProp optimizer and contrastive loss.
    :param model: torch.nn.Module
        The Pytorch model to train
    :param train_set: DataLoader
        The data to train the model
    :param test_set: DataLoader
        The data to test the efficiency of the model
    :param epochs_: int
        The number of epochs to train. If -1, it will train until the early stopping stops the training
    :param learning_rate: float
        Starting learning rate used during the train phase
    :param model_name: str
        Specify the model name, just to save the files correctly
    :param cuda: bool
        If the model is in cuda
    :param start_epoch_: int
        To continue a previous run; indicates at which epoch the current training starts.
    :param adam: bool
        If true, uses Adam as the optimizer; if false, uses RMSprop
    :param patience: int
        The number of epochs used in the early stopping
    :return: (list, list)
        The train and test losses
    """
    losses_ = []
    test_losses_ = []
    best_score = float("inf")
    best_model = None
    not_improved = 0
    until_converge = False

    if adam:
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    else:
        optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
    loss = TripletMarginLoss()

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=10,
                                                   gamma=0.1)

    if start_epoch_ != 0:
        lr_scheduler.load_state_dict(
            torch.load(f"./models/triplet/lr_scheduler_{start_epoch_ - 1}.pt"))

    if epochs_ == -1:
        epochs_ = start_epoch_ + 1
        until_converge = True

    epoch = start_epoch_
    while epoch < epochs_:
        sys.stdout.flush()
        epoch_loss = 0

        for anchor, positive, negative in tqdm(train_set,
                                               total=len(train_set),
                                               desc='Train'):
            if cuda:
                anchor = anchor.cuda()
                positive = positive.cuda()
                negative = negative.cuda()

            optimizer.zero_grad()

            anchor_vector = model.model(anchor)
            positive_vector = model.model(positive)
            negative_vector = model.model(negative)

            current_loss = loss(anchor_vector, positive_vector,
                                negative_vector)
            current_loss.backward()

            epoch_loss += current_loss.detach().cpu().item()

            optimizer.step()

        # Test step
        sys.stdout.flush()
        test_loss = 0
        with torch.no_grad():
            for anchor, positive, negative in tqdm(test_set,
                                                   total=len(test_set),
                                                   desc="Test"):
                if cuda:
                    anchor = anchor.cuda()
                    positive = positive.cuda()
                    negative = negative.cuda()

                anchor_vector = model.model(anchor)
                positive_vector = model.model(positive)
                negative_vector = model.model(negative)
                test_loss += loss(anchor_vector, positive_vector,
                                  negative_vector).detach().cpu().item()

        epoch_loss /= len(train_set)
        test_loss /= len(test_set)
        losses_.append(epoch_loss)
        test_losses_.append(test_loss)

        message = f'Epoch: {epoch + 1}/{epochs_}, Loss: {epoch_loss:.4f}, ' + \
                  f'Test loss: {test_loss:.4f}\n'
        sys.stdout.write(message)

        with open(f'./train_{model_name}.log', 'a') as f:
            f.write(message)

        torch.save(model.state_dict(),
                   f'./models/triplet/{model_name}_{epoch}.pt')
        torch.save(lr_scheduler.state_dict(),
                   f'./models/triplet/lr_scheduler_{epoch}.pt')

        # Get the best model
        if test_loss < best_score:
            best_score = test_loss
            best_model = model.state_dict()
            not_improved = 0

        else:
            not_improved += 1

        # Early stopping
        if not_improved == patience:
            break

        lr_scheduler.step()  # advance the LR schedule; it was created and checkpointed above but never stepped
        epoch += 1
        if until_converge:
            epochs_ += 1

    torch.save(best_model, './models/triplet/best-' + model_name + '.pt')
    return losses_, test_losses_
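A minimal invocation sketch, assuming `model` wraps the embedding network in a `.model` attribute and both loaders yield (anchor, positive, negative) batches (names are illustrative):

train_losses, test_losses = train(model,
                                  train_set=train_loader,
                                  test_set=test_loader,
                                  epochs_=-1,  # run until early stopping triggers
                                  learning_rate=1e-3,
                                  model_name='siamese',
                                  cuda=torch.cuda.is_available(),
                                  patience=5)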
Example #12
def test_triplet_loss():

    from torch.nn import TripletMarginLoss
    tl = TripletMarginLoss(margin=1.0, p=2)

    # test with 2 datasets
    n_batch = 6
    n_dims = 3
    x = torch.zeros((n_batch, n_dims))
    y = torch.ones((n_batch, n_dims))
    datasets = np.concatenate([np.zeros((n_batch,)), np.ones((n_batch,))])
    loss = losses.triplet_loss(tl, torch.cat([x, y], 0), datasets)
    assert np.isclose(loss.item(), 0, atol=1e-5)

    x = torch.zeros((n_batch, n_dims))
    y = 2 * torch.ones((n_batch, n_dims))
    datasets = np.concatenate([np.zeros((n_batch,)), np.ones((n_batch,))])
    loss = losses.triplet_loss(tl, torch.cat([x, y], 0), datasets)
    assert np.isclose(loss.item(), 0, atol=1e-5)

    t1 = 0.50
    x = torch.zeros((n_batch, n_dims))
    y = t1 * torch.ones((n_batch, n_dims))
    datasets = np.concatenate([np.zeros((n_batch,)), np.ones((n_batch,))])
    loss = losses.triplet_loss(tl, torch.cat([x, y], 0), datasets)
    val = (-np.sqrt(n_dims * t1 ** 2) + 1) * 2 / 3
    assert np.isclose(loss.item(), val, atol=1e-5)

    # test with 3 datasets
    t1 = 0.25
    t2 = 0.50
    x = torch.zeros((n_batch, n_dims))
    y = t1 * torch.ones((n_batch, n_dims))
    z = t2 * torch.ones((n_batch, n_dims))
    datasets = np.concatenate([np.zeros((n_batch,)), np.ones((n_batch,)), 2 * np.ones((n_batch,))])
    loss = losses.triplet_loss(tl, torch.cat([x, y, z], 0), datasets)
    val1 = (-np.sqrt(n_dims * t1 ** 2) + 1)
    val2 = (-np.sqrt(n_dims * t2 ** 2) + 1)
    val = (4 * val1 + 2 * val2) / 6
    assert np.isclose(loss.item(), val, atol=1e-5)

    # test with 4 datasets
    n_batch = 9
    t1 = 0.1
    t2 = 0.2
    t3 = 0.3
    x = torch.zeros((n_batch, n_dims))
    y = t1 * torch.ones((n_batch, n_dims))
    z = t2 * torch.ones((n_batch, n_dims))
    v = t3 * torch.ones((n_batch, n_dims))
    datasets = np.concatenate(
        [np.zeros((n_batch,)), np.ones((n_batch,)), 2 * np.ones((n_batch,)),
         3 * np.ones((n_batch,))])
    loss = losses.triplet_loss(tl, torch.cat([x, y, z, v], 0), datasets)
    val1 = (-np.sqrt(n_dims * t1 ** 2) + 1)
    val2 = (-np.sqrt(n_dims * t2 ** 2) + 1)
    val3 = (-np.sqrt(n_dims * t3 ** 2) + 1)
    val = (6 * val1 + 4 * val2 + 2 * val3) / 12
    print(val)
    print(loss.item())
    assert np.isclose(loss.item(), val, atol=1e-5)
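The expected values above follow from the loss definition max(d(a, p) - d(a, n) + margin, 0): anchors and positives coincide within a dataset, so d(a, p) = 0, while d(a, n) = sqrt(n_dims * t**2) for a negative offset by t. A standalone check of the per-triplet value:

import torch
from torch.nn import TripletMarginLoss

tl = TripletMarginLoss(margin=1.0, p=2)
n_dims, t1 = 3, 0.50
anchor = torch.zeros((1, n_dims))
positive = torch.zeros((1, n_dims))      # d(a, p) = 0
negative = t1 * torch.ones((1, n_dims))  # d(a, n) = sqrt(3 * 0.25) ~ 0.866
# expected per-triplet loss: 1.0 - 0.866 ~ 0.134
print(tl(anchor, positive, negative).item())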
Example #13
def loss(self):
    """Build the triplet criterion used for training."""
    # return TripletLoss(margin=self.margin, reduction=self.loss_reduction)
    return TripletMarginLoss(margin=self.margin, p=2.0)
Example #14
def getPredictionLossFn(cl=None, net=None):
    kldivLoss = KLDivLoss()
    mseLoss = MSELoss()
    smoothl1Loss = SmoothL1Loss()
    tripletLoss = TripletMarginLoss()  #TripletLoss()
    cosineLoss = CosineEmbeddingLoss(margin=0.5)
    if PREDICTION_LOSS == 'MSE':

        def prediction_loss(predFeature, nextFeature):
            return mseLoss(predFeature, nextFeature)
    elif PREDICTION_LOSS == 'SMOOTHL1':

        def prediction_loss(predFeature, nextFeature):
            return smoothl1Loss(predFeature, nextFeature)
    elif PREDICTION_LOSS == 'TRIPLET':

        def prediction_loss(predFeature,
                            nextFeature,
                            negativeFeature=None,
                            cl=cl,
                            net=net):
            if negativeFeature is None:  # truthiness of a multi-element tensor would raise here
                negatives, _, _ = cl.randomSamples(1)  #predFeature.size(0))
                negativeFeature = net(
                    Variable(negatives[0], requires_grad=False).cuda(),
                    Variable(negatives[1],
                             requires_grad=False).cuda()).detach()
            return tripletLoss(predFeature.unsqueeze(0),
                               nextFeature.unsqueeze(0), negativeFeature)
    elif PREDICTION_LOSS == 'COSINE':

        def prediction_loss(predFeature,
                            nextFeature,
                            negativeFeature=None,
                            cl=cl,
                            net=net):
            if negativeFeature is None:
                negatives, _, _ = cl.randomSamples(1)  #predFeature.size(0))
                negativeFeature = net(
                    Variable(negatives[0], requires_grad=False).cuda(),
                    Variable(negatives[1],
                             requires_grad=False).cuda()).detach()
            else:
                negativeFeature = negativeFeature.unsqueeze(0)
            predFeature = predFeature.unsqueeze(0)
            nextFeature = nextFeature.unsqueeze(0)
            # concat positive and negative features
            # create targets for concatenated positives and negatives
            input1 = torch.cat([predFeature, predFeature], dim=0)
            input2 = torch.cat([nextFeature, negativeFeature], dim=0)
            target1 = Variable(torch.ones(predFeature.size(0)),
                               requires_grad=False).detach().cuda()
            target2 = -target1
            target = torch.cat([target1, target2], dim=0)
            return cosineLoss(input1, input2, target)
    else:

        def prediction_loss(predFeature, nextFeature):
            return kldivLoss(F.log_softmax(predFeature, dim=-1),
                             F.softmax(nextFeature, dim=-1))

    return prediction_loss
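In the 'COSINE' branch above, positive and negative pairs are stacked into a single CosineEmbeddingLoss call with targets +1 (pull together) and -1 (push apart). A self-contained sketch of that construction with random features (shapes are illustrative):

import torch
from torch.nn import CosineEmbeddingLoss

cosine_loss = CosineEmbeddingLoss(margin=0.5)
pred = torch.randn(2, 16)
pos = torch.randn(2, 16)
neg = torch.randn(2, 16)
# duplicate the predictions so each row is compared against its positive
# (target +1) and its negative (target -1) in one call
input1 = torch.cat([pred, pred], dim=0)
input2 = torch.cat([pos, neg], dim=0)
target = torch.cat([torch.ones(2), -torch.ones(2)], dim=0)
print(cosine_loss(input1, input2, target).item())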
Example #15
def train_autoencoder(
    autoencoder,
    train_dataloader,
    val_dataloader,
    checkpoint_file,
    tb_writer,
    triplet_loss_weight: float = 0.0,
    num_epochs: int = 40,
    deep_supervision: bool = False,
    learning_rate: float = 0.01,
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    autoencoder.to(device)
    bce_loss = WeightedBinaryCrossEntropyLoss(beta=0.9).to(device)
    if triplet_loss_weight != 0.0:
        trplt_loss = TripletMarginLoss(swap=True).to(device)
    lowest_validation_loss = 100
    optimizer = torch.optim.Adam(autoencoder.parameters(),
                                 lr=learning_rate,
                                 weight_decay=0.0001)
    if os.path.isfile(checkpoint_file):
        ckpt = torch.load(checkpoint_file, map_location=device)
        autoencoder.load_state_dict(ckpt["state_dict"])
        autoencoder.global_step = ckpt["global_step"]
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        lowest_validation_loss = ckpt["lowest_val_loss"]
        print(f"loaded checkpoint {checkpoint_file}. "
              f"Starting from {autoencoder.global_step}")

    while autoencoder.global_step < num_epochs:
        losses = {"train": [], "val": []}
        ious = {"train": [], "val": []}
        autoencoder.train()
        for batch_idx, data in enumerate(tqdm(train_dataloader)):
            optimizer.zero_grad()
            predictions = autoencoder(data["reference"].to(device),
                                      intermediate_outputs=deep_supervision)
            if not deep_supervision:
                loss = bce_loss(predictions, data["mask"].to(device))
            else:
                loss = 0
                for output in predictions:
                    loss += (1 / len(predictions) *
                             bce_loss(output, data["mask"].to(device)))
            if triplet_loss_weight != 0:
                anchor = autoencoder.project(data["reference"].to(device))
                positive = autoencoder.project(data["positive"].to(device))
                negative = autoencoder.project(data["negative"].to(device))
                triplet = triplet_loss_weight * trplt_loss(
                    anchor, positive, negative)
                # print(
                # f"loss: {loss} vs trplt: {triplet} (weight: {triplet_loss_weight})"
                # )
                loss += triplet
            loss.backward()
            optimizer.step()
            losses["train"].append(loss.cpu().detach().item())
            ious["train"].append(
                get_IoU(
                    autoencoder(data["reference"].to(device),
                                intermediate_outputs=False),
                    data["mask"].to(device),
                ).detach().cpu().item())
        autoencoder.eval()
        for batch_idx, data in enumerate(val_dataloader):
            predictions = autoencoder(data["reference"].to(device),
                                      intermediate_outputs=False)
            loss = bce_loss(predictions, data["mask"].to(device))
            losses["val"].append(loss.cpu().detach().item())
            ious["val"].append(
                get_IoU(predictions,
                        data["mask"].to(device)).detach().cpu().item())

        print(
            f"{get_date_time_tag()}: epoch {autoencoder.global_step} - "
            f"train {np.mean(losses['train']): 0.3f} [{np.std(losses['train']): 0.2f}]"
            f" - val {np.mean(losses['val']): 0.3f} [{np.std(losses['val']): 0.2f}]"
        )
        tb_writer.add_scalar(
            "train/bce_loss/autoencoder" +
            ("" if not triplet_loss_weight else "_trplt"),
            np.mean(losses["train"]),
            global_step=autoencoder.global_step,
        )
        tb_writer.add_scalar(
            "val/bce_loss/autoencoder" +
            ("" if not triplet_loss_weight else "_trplt"),
            np.mean(losses["val"]),
            global_step=autoencoder.global_step,
        )
        tb_writer.add_scalar(
            "train/iou/autoencoder",
            np.mean(ious["train"]),
            global_step=autoencoder.global_step,
        )
        tb_writer.add_scalar(
            "val/iou/autoencoder",
            np.mean(ious["val"]),
            global_step=autoencoder.global_step,
        )
        autoencoder.global_step += 1
        if lowest_validation_loss > np.mean(losses["val"]):
            lowest_validation_loss = np.mean(losses["val"])
            ckpt = {
                "state_dict": autoencoder.state_dict(),
                "global_step": autoencoder.global_step,
                "optimizer_state_dict": optimizer.state_dict(),
                "lowest_val_loss": lowest_validation_loss,
            }
            torch.save(
                ckpt,
                checkpoint_file,
            )
            with open(os.path.join(tb_writer.get_logdir(), "results.txt"),
                      "w") as f:
                f.write(
                    f"validation_bce_loss_avg: "
                    f"{np.mean(losses['val']):10.3e}\n", )
                f.write(
                    f"validation_bce_loss_std: "
                    f"{np.std(losses['val']):10.2e}\n", )
                f.write(
                    f"validation_iou_avg: "
                    f"{np.mean(ious['val']):10.3e}\n", )
                f.write(
                    f"validation_iou_std: "
                    f"{np.std(ious['val']):10.2e}\n", )
            print(f"Saved model in {checkpoint_file}.")
    autoencoder.to(torch.device("cpu"))
Example #16
def train(model, config):
  # clear_output_dir()
  optimizer, lr_scheduler = init_training(model, config)
  logger = Logger(config)

  # TODO: Check which images it thinks are similar from e.g. copydays.

  # transformer = AllTransformer()
  # transformer = JpgTransformer()
  # transformer = RotateTransformer()
  # transformer = FlipTransformer()
  # transformer = RotateCropTransformer()
  transformer = CropTransformer()
  validator = Validator(model, logger, config, transformer)

  margin = 5
  triplet_loss_fn = TripletMarginLoss(margin,
                                      p=config.distance_norm,
                                      swap=True)
  neg_cos_loss_fn = ZeroCosineLoss(margin=0.1)
  pos_cos_loss_fn = PositiveCosineLoss(margin=0.1)
  similarity_loss_fn = torch.nn.BCELoss()

  # Data
  dataloader = setup_traindata(config, transformer)

  # Init progressbar
  n_batches = len(dataloader)
  n_epochs = math.ceil(config.optim_steps / n_batches)
  pbar = Progressbar(n_epochs, n_batches)

  optim_steps = 0
  val_freq = config.validation_freq

  # Training loop
  for epoch in pbar(range(1, n_epochs + 1)):
    for batch_i, data in enumerate(dataloader, 1):
      pbar.update(epoch, batch_i)

      # Validation
      # if optim_steps % val_freq == 0:
      #   validator.validate(optim_steps)
      print("START")

      # Decrease learning rate
      if optim_steps % config.lr_step_frequency == 0:
        lr_scheduler.step()

      optimizer.zero_grad()
      original, transformed = data

      inputs = torch.cat((original, transformed))
      outputs = model(inputs)
      original_emb, transf_emb = outputs
      anchors, positives, negatives = create_triplets(original_emb, transf_emb)

      # Triplet loss
      triplet_loss = triplet_loss_fn(anchors, positives, negatives)
      # anchors, positives = scale_embeddings(anchors, positives, model)
      # anchors, negatives = scale_embeddings(anchors, negatives, model)

      # Cosine similarity loss
      # cos_match_loss = pos_cos_loss_fn(anchors, positives)
      # cos_not_match_loss = neg_cos_loss_fn(anchors, negatives)

      # Direct net loss
      # a_p, a_n = model.cc_similarity_net(anchors, positives, negatives)
      # net_match_loss = similarity_loss_fn(a_p, torch.ones_like(a_p))
      # net_not_match_loss = similarity_loss_fn(a_n, torch.zeros_like(a_n))
      # net_loss = net_match_loss + net_not_match_loss

      # loss_dict = dict(triplet=triplet_loss, cos_pos=cos_match_loss, cos_neg=cos_not_match_loss)
      # loss_dict = dict(cos_pos=cos_match_loss, cos_neg=cos_not_match_loss)
      loss_dict = dict(triplet_loss=triplet_loss)
      # loss_dict = dict(direct_match=net_match_loss,
      #                  direct_not_match=net_not_match_loss)

      loss = sum(loss_dict.values())
      loss.backward()
      optimizer.step()
      optim_steps += 1

      corrects = model.corrects(transf_emb, original_emb)
      logger.easy_or_hard(anchors, positives, negatives, margin, optim_steps)
      logger.log_loss(loss, optim_steps)
      logger.log_loss_percent(loss_dict, optim_steps)
      logger.log_corrects(corrects, optim_steps)
      logger.log_cosine(anchors, positives, negatives)
      # logger.log_p(model.feature_extractor.pool.p, optim_steps)
      # logger.log_weights(model.feature_extractor.sim_weights)

      # Frees up GPU memory
      del data
      del outputs
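`create_triplets` is not shown in this snippet; a minimal sketch of one way it could work, assuming each original embedding is the anchor and its transformed counterpart the positive (the roll-by-one negative selection is an illustrative assumption):

import torch

def create_triplets(original_emb: torch.Tensor, transf_emb: torch.Tensor):
    # hypothetical sketch: anchor = original, positive = its transformed
    # version, negative = the transformed embedding of a different sample
    # (batch rolled by one position)
    negatives = torch.roll(transf_emb, shifts=1, dims=0)
    return original_emb, transf_emb, negatives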