import numpy as np
import pytest
from hypothesis import strategies as st

import mygrad as mg
from mygrad.nnet.losses import margin_ranking_loss


def test_input_validation(args, data: st.DataObject):
    # known-good baseline arguments for margin_ranking_loss
    kwargs = dict(x1=np.ones((3, 4)), y=mg.Tensor(1), margin=0.5)

    if "x2" not in args:
        args["x2"] = np.ones_like(kwargs["x1"])

    # swap in the bad values from `args`: hypothesis strategies are
    # drawn from, literal values are used as-is
    kwargs.update(
        (k, data.draw(v, label=k) if isinstance(v, st.SearchStrategy) else v)
        for k, v in args.items())

    with pytest.raises((ValueError, TypeError)):
        margin_ranking_loss(**kwargs)
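
This test body is clearly meant to be driven by decorators that are not shown in the snippet: a pytest parametrization supplying one bad argument per case, plus hypothesis's `st.data()`. A minimal sketch of one plausible wiring (the specific bad-input dictionaries here are hypothetical, not from the original source):

import hypothesis.strategies as st
import pytest
from hypothesis import given


@pytest.mark.parametrize(
    "args",
    [
        dict(x1="not an array"),                      # hypothetical: wrong type for x1
        dict(margin=st.sampled_from([None, "0.5"])),  # hypothetical: bad margin values
    ],
)
@given(data=st.data())
def test_input_validation(args, data: st.DataObject):
    ...  # body as shown above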
Example #2
import numpy as np
from numpy.testing import assert_allclose, assert_array_equal

import hypothesis.extra.numpy as hnp
import hypothesis.strategies as st

import mygrad as mg
from mygrad.nnet.losses import margin_ranking_loss


# driven by hypothesis `@given` decorators supplying shape/margin/data (not shown)
def test_ranked_margin(shape, margin, data):
    x1 = data.draw(
        hnp.arrays(shape=shape, dtype=float, elements=st.floats(-1000, 1000)),
        label="x1",
    )
    x2 = data.draw(
        hnp.arrays(shape=shape, dtype=float, elements=st.floats(-1000, 1000)),
        label="x2",
    )
    # y is either a single ±1 label or a vector of ±1 labels along the leading axis
    y = data.draw(
        st.sampled_from((-1, 1))
        | hnp.arrays(
            shape=shape[:1],
            dtype=hnp.integer_dtypes(),
            elements=st.sampled_from((-1, 1)),
        ),
        label="y",
    )

    # keep copies so we can verify that neither the forward nor the
    # backward pass mutates its inputs
    x1_copy = np.copy(x1)
    x2_copy = np.copy(x2)
    y_copy = np.copy(y)

    # independent tensor pairs: one fed to the reference ("dummy")
    # implementation, one to the real margin_ranking_loss
    x1_dum = mg.Tensor(x1)
    x2_dum = mg.Tensor(x2)

    x1_real = mg.Tensor(x1)
    x2_real = mg.Tensor(x2)

    loss_dum = simple_loss(x1_dum, x2_dum, y, margin)

    loss_real = margin_ranking_loss(x1_real, x2_real, y, margin)

    assert_allclose(actual=loss_real.data,
                    desired=loss_dum.data,
                    err_msg="losses don't match")

    assert_array_equal(x1, x1_copy, err_msg="`x1` was mutated by forward")
    assert_array_equal(x2, x2_copy, err_msg="`x2` was mutated by forward")
    if isinstance(y, np.ndarray):
        assert_array_equal(y, y_copy, err_msg="`y` was mutated by forward")

    loss_dum.backward()
    loss_real.backward()

    assert_allclose(actual=x1_real.grad,
                    desired=x1_dum.grad,
                    err_msg="x1.grad doesn't match")
    assert_allclose(actual=x2_real.grad,
                    desired=x2_dum.grad,
                    err_msg="x2.grad doesn't match")

    assert_array_equal(x1, x1_copy, err_msg="`x1` was mutated by backward")
    assert_array_equal(x2, x2_copy, err_msg="`x2` was mutated by backward")
    if isinstance(y, np.ndarray):
        assert_array_equal(y, y_copy, err_msg="`y` was mutated by backward")

    loss_real.null_gradients()
    assert x1_real.grad is None
    assert x2_real.grad is None
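
The `simple_loss` oracle that this test compares against is not shown in the snippet. A hedged sketch of what it plausibly computes, assuming the standard margin ranking loss, mean(max(0, margin - y * (x1 - x2))), written with mygrad ops so its gradients can be checked against `margin_ranking_loss`:

import numpy as np
import mygrad as mg


def simple_loss(x1, x2, y, margin):
    # broadcast a per-row `y` of shape (N,) against the trailing dims of x1/x2
    y = np.asarray(y)
    if y.ndim:
        y = y.reshape(-1, *([1] * (x1.ndim - 1)))
    # standard margin ranking loss: mean(max(0, margin - y * (x1 - x2)))
    return mg.mean(mg.maximum(0, margin - y * (x1 - x2)))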
Example #3
from mygrad.nnet.losses import margin_ranking_loss


def train(model, text_emb, good_img, bad_img, optim):
    # score the matching and non-matching image against the text embedding
    sim_to_good = sim(text_emb, model(good_img))
    sim_to_bad = sim(text_emb, model(bad_img))
    # y=1: the good image should outscore the bad one by at least the 0.1 margin
    loss = margin_ranking_loss(sim_to_good, sim_to_bad, 1, 0.1)
    loss.backward()
    optim.step()
    loss.null_gradients()
    return loss.item(), int(sim_to_good > sim_to_bad)
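
The `sim` helper is not defined in this snippet; it is presumably a similarity score between a text embedding and an image embedding. A minimal sketch, assuming cosine similarity implemented with mygrad ops so it stays differentiable:

import mygrad as mg


def sim(a, b):
    # cosine similarity: dot product of L2-normalized embeddings
    a = a / mg.sqrt(mg.sum(a ** 2, axis=-1, keepdims=True))
    b = b / mg.sqrt(mg.sum(b ** 2, axis=-1, keepdims=True))
    return mg.sum(a * b, axis=-1)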
Example #4
import numpy as np

import mygrad as mg
from mygrad.nnet.losses import margin_ranking_loss

# `SGD`, `utils`, and `accuracy` are project-local helpers (SGD is
# presumably MyNN's optimizer); they are not defined in this snippet.


def train(model,
          num_epochs,
          margin,
          triplets,
          learning_rate=0.1,
          batch_size=32):
    """ trains the model 
        
        Parameters
        ----------
        
        model -  Model
            an initizized Model class, with input and output dim matching the image ID(512) and the descriptor (50) 
        
        num_epochs - int
            amount of epochs
            
        margin - int
            marhine for the margine ranking loss
            
        triplets 
            triplets created with the data from all_triplets(path)
        
        learning_rate(optional) - int
            learning rate of SDG
            
        batch_size(optional) - int
            the batch size
            

        Returns
        -------
        it trains the model by minimizing the loss function
        
        """
    optim = SGD(model.parameters, learning_rate=learning_rate)

    for epoch_cnt in range(num_epochs):
        # shuffle the order of the triplets each epoch
        idxs = np.arange(len(triplets))
        np.random.shuffle(idxs)

        for batch_cnt in range(0, len(triplets) // batch_size):

            batch_indices = idxs[batch_cnt * batch_size:(batch_cnt + 1) *
                                 batch_size]
            triplets_batch = [triplets[index] for index in batch_indices]

            # each triplet is (caption, good image, bad image)
            good_pic_batch = np.array([val[1] for val in triplets_batch])
            bad_pic_batch = np.array([val[2] for val in triplets_batch])
            caption_batch = np.array([val[0] for val in triplets_batch])

            # embed both image batches and L2-normalize the embeddings
            good_pic_pred = model(good_pic_batch)
            bad_pic_pred = model(bad_pic_batch)
            good_pic_pred = good_pic_pred / mg.sqrt(
                mg.sum(mg.power(good_pic_pred, 2), axis=-1, keepdims=True))
            bad_pic_pred = bad_pic_pred / mg.sqrt(
                mg.sum(mg.power(bad_pic_pred, 2), axis=-1, keepdims=True))

            # dot products score each caption against its good and bad image
            Sgood = (good_pic_pred * caption_batch).sum(axis=-1)
            Sbad = (bad_pic_pred * caption_batch).sum(axis=-1)

            # the good image should outscore the bad image by at least `margin`
            loss = margin_ranking_loss(Sgood, Sbad, 1, margin)
            acc = accuracy(Sgood.flatten(), Sbad.flatten())
            if batch_cnt % 10 == 0:
                print(loss, acc)

            loss.backward()
            optim.step()
            loss.null_gradients()
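
The `accuracy` helper used above is project-local and not shown. A minimal sketch consistent with how it is called here: the fraction of triplets in which the good image outscores the bad one (the name and behavior are assumptions):

import numpy as np


def accuracy(s_good, s_bad):
    # unwrap mygrad Tensors to plain arrays if necessary
    s_good = np.asarray(getattr(s_good, "data", s_good))
    s_bad = np.asarray(getattr(s_bad, "data", s_bad))
    # fraction of pairs where the matching image is ranked higher
    return np.mean(s_good > s_bad)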