import numpy as np
import pytest
from hypothesis import given, strategies as st
import hypothesis.extra.numpy as hnp
from numpy.testing import assert_allclose, assert_array_equal

import mygrad as mg

# `margin_ranking_loss` (the function under test) and `simple_loss` (an
# independent reference implementation) are assumed to be importable from the
# surrounding project.


def test_input_validation(args, data: st.DataObject):
    # `args` (the bad argument(s) to splice in) and `data` are supplied by the
    # test harness, e.g. via pytest parametrization and Hypothesis's `given`.
    # Start from a known-good call signature...
    kwargs = dict(x1=np.ones((3, 4)), y=mg.Tensor(1), margin=0.5)

    if "x2" not in args:
        args["x2"] = np.ones_like(kwargs["x1"])

    # ...then overwrite entries with the invalid values; any Hypothesis
    # strategy embedded in `args` is drawn on demand.
    kwargs.update(
        (k, data.draw(v, label=k) if isinstance(v, st.SearchStrategy) else v)
        for k, v in args.items()
    )

    with pytest.raises((ValueError, TypeError)):
        margin_ranking_loss(**kwargs)
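
# A sketch of how the validation test above might be driven (not taken from
# the project): pytest parametrization supplies the bad argument(s), and
# Hypothesis's `given(data=st.data())` supplies the `data` object used to
# draw from any strategy embedded in `args`. The specific bad inputs below
# (a negative or non-numeric margin, a `None` target) are assumptions about
# what the implementation rejects.
@pytest.mark.parametrize(
    "args",
    [
        dict(margin=st.sampled_from([-1.0, None])),  # invalid margin, drawn lazily
        dict(y=None),                                # invalid target
    ],
)
@given(data=st.data())
def test_input_validation_examples(args, data: st.DataObject):
    test_input_validation(args, data)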

def test_ranked_margin(shape, margin, data):
    x1 = data.draw(
        hnp.arrays(shape=shape, dtype=float, elements=st.floats(-1000, 1000)),
        label="x1",
    )
    x2 = data.draw(
        hnp.arrays(shape=shape, dtype=float, elements=st.floats(-1000, 1000)),
        label="x2",
    )
    # `y` is either a single +/-1 or an array of +/-1 values, one per row
    y = data.draw(
        st.sampled_from((-1, 1))
        | hnp.arrays(
            shape=shape[:1],
            dtype=hnp.integer_dtypes(),
            elements=st.sampled_from((-1, 1)),
        ),
        label="y",
    )

    # keep copies so we can verify that no input is mutated
    x1_copy = np.copy(x1)
    x2_copy = np.copy(x2)
    y_copy = np.copy(y)

    x1_dum = mg.Tensor(x1)
    x2_dum = mg.Tensor(x2)
    x1_real = mg.Tensor(x1)
    x2_real = mg.Tensor(x2)

    # compare the library implementation against the reference implementation
    loss_dum = simple_loss(x1_dum, x2_dum, y, margin)
    loss_real = margin_ranking_loss(x1_real, x2_real, y, margin)

    assert_allclose(actual=loss_real.data, desired=loss_dum.data, err_msg="losses don't match")

    assert_array_equal(x1, x1_copy, err_msg="`x1` was mutated by forward")
    assert_array_equal(x2, x2_copy, err_msg="`x2` was mutated by forward")
    if isinstance(y, np.ndarray):
        assert_array_equal(y, y_copy, err_msg="`y` was mutated by forward")

    loss_dum.backward()
    loss_real.backward()

    assert_allclose(actual=x1_real.grad, desired=x1_dum.grad, err_msg="x1.grad doesn't match")
    assert_allclose(actual=x2_real.grad, desired=x2_dum.grad, err_msg="x2.grad doesn't match")

    assert_array_equal(x1, x1_copy, err_msg="`x1` was mutated by backward")
    assert_array_equal(x2, x2_copy, err_msg="`x2` was mutated by backward")
    if isinstance(y, np.ndarray):
        assert_array_equal(y, y_copy, err_msg="`y` was mutated by backward")

    # nulling the gradients should clear them from the graph
    loss_real.null_gradients()
    assert x1_real.grad is None
    assert x2_real.grad is None
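
# `simple_loss` is referenced above but not defined in this excerpt. A minimal
# sketch of such a reference implementation, assuming the standard
# margin-ranking formulation, mean(max(0, margin - y * (x1 - x2))), written
# with elementary mygrad ops so that autodiff supplies the gradients the test
# compares against:
def simple_loss(x1, x2, y, margin):
    y = np.asarray(y)
    if y.ndim == 1 and x1.ndim == 2:
        # align the per-row labels with the rows of (x1 - x2) for broadcasting
        y = y.reshape(-1, 1)
    return mg.mean(mg.maximum(0, margin - y * (x1 - x2)))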

def train(model, text_emb, good_img, bad_img, optim):
    # similarity of the caption embedding to the matching and confuser images
    sim_to_good = sim(text_emb, model(good_img))
    sim_to_bad = sim(text_emb, model(bad_img))

    # the matching image should out-rank the confuser by at least the margin (0.1)
    loss = margin_ranking_loss(sim_to_good, sim_to_bad, 1, 0.1)
    loss.backward()
    optim.step()
    loss.null_gradients()

    # report the loss and whether the matching image was ranked higher
    return loss.item(), int(sim_to_good > sim_to_bad)
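
# `sim` is not shown in this excerpt. A plausible sketch, assuming it is the
# cosine similarity between the caption embedding and the image embedding,
# computed with mygrad ops so that the result stays in the computational graph
# consumed by `margin_ranking_loss`:
def sim(text_emb, img_emb):
    text_emb = text_emb / mg.sqrt(mg.sum(text_emb ** 2, axis=-1, keepdims=True))
    img_emb = img_emb / mg.sqrt(mg.sum(img_emb ** 2, axis=-1, keepdims=True))
    return mg.sum(text_emb * img_emb, axis=-1)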

# `SGD`, `load_resnet`, `utils`, and `accuracy` are assumed to be provided by
# the surrounding project.
def train(model, num_epochs, margin, triplets, learning_rate=0.1, batch_size=32):
    """Trains the model by minimizing the margin-ranking loss.

    Parameters
    ----------
    model : Model
        An initialized Model whose input and output dimensions match the
        image feature vector (512) and the descriptor (50).
    num_epochs : int
        Number of epochs to train for.
    margin : float
        Margin for the margin-ranking loss.
    triplets
        Triplets created with the data from all_triplets(path).
    learning_rate : float, optional (default=0.1)
        Learning rate of SGD.
    batch_size : int, optional (default=32)
        The batch size.

    Returns
    -------
    None
        The model is trained in place.
    """
    optim = SGD(model.parameters, learning_rate=learning_rate)

    # NOTE: the `triplets` argument is replaced here by the triplets loaded from disk
    triplets = load_resnet(r"data\triplets")
    images = utils.get_img_ids()

    for epoch_cnt in range(num_epochs):
        # shuffle the data at the start of each epoch (the image ids and the
        # triplets are assumed to be aligned and of equal length)
        idxs = np.arange(len(images))
        np.random.shuffle(idxs)

        for batch_cnt in range(0, len(images) // batch_size):
            batch_indices = idxs[batch_cnt * batch_size : (batch_cnt + 1) * batch_size]
            triplets_batch = [triplets[index] for index in batch_indices]

            # each triplet is (caption embedding, matching image, confuser image)
            good_pic_batch = np.array([val[1] for val in triplets_batch])
            bad_pic_batch = np.array([val[2] for val in triplets_batch])
            caption_batch = np.array([val[0] for val in triplets_batch])

            # embed both image batches and L2-normalize the embeddings
            good_pic_pred = model(good_pic_batch)
            bad_pic_pred = model(bad_pic_batch)
            good_pic_pred = good_pic_pred / mg.sqrt(
                mg.sum(mg.power(good_pic_pred, 2), axis=-1, keepdims=True)
            )
            bad_pic_pred = bad_pic_pred / mg.sqrt(
                mg.sum(mg.power(bad_pic_pred, 2), axis=-1, keepdims=True)
            )

            # dot-product similarity of each caption with each image embedding
            Sgood = (good_pic_pred * caption_batch).sum(axis=-1)
            Sbad = (bad_pic_pred * caption_batch).sum(axis=-1)

            loss = margin_ranking_loss(Sgood, Sbad, 1, margin)
            acc = accuracy(Sgood.flatten(), Sbad.flatten())

            if batch_cnt % 10 == 0:
                print(loss, acc)

            loss.backward()
            optim.step()
            loss.null_gradients()
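
# `accuracy` is not defined in this excerpt. A minimal sketch, assuming it
# reports the fraction of triplets in the batch whose matching image scores
# higher than its confuser (no gradients are needed, so it operates on the
# underlying numpy data):
def accuracy(sim_good, sim_bad):
    sim_good = sim_good.data if isinstance(sim_good, mg.Tensor) else np.asarray(sim_good)
    sim_bad = sim_bad.data if isinstance(sim_bad, mg.Tensor) else np.asarray(sim_bad)
    return float(np.mean(sim_good > sim_bad))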