Learning a distance metric with a siamese CNN to classify sentiment, and attempting to show that the siamese CNN is robust to blind spots whereas the CNN is not


xiongshufeng/siamese_sentiment

 
 


The task is sentiment classification, but the goal of the project is to show that
the CNN that makes up both branches of the siamese network is susceptible to adversarial
examples and blind spots like those described in https://arxiv.org/abs/1312.6199, while the
siamese network built from the same CNN is not.

To train the model, create the <train> directory inside <training_data>.
The <train> directory should contain the sub-directories <pos> and <neg>,
which hold the files with positive reviews and negative reviews from the
Stanford sentiment dataset.
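
The layout described above can be created with a short script. This is a minimal sketch that assumes the directories are literally named training_data, train, pos and neg; the helper make_training_layout is hypothetical and not part of this repository. The review files themselves still have to be copied in by hand.

```python
import os
import tempfile


def make_training_layout(root="training_data"):
    """Create <root>/train/pos and <root>/train/neg and return both paths.

    The directory names are an assumption based on the README text.
    """
    paths = []
    for label in ("pos", "neg"):
        path = os.path.join(root, "train", label)
        os.makedirs(path, exist_ok=True)  # no error if it already exists
        paths.append(path)
    return paths


if __name__ == "__main__":
    # Demonstrate inside a throwaway directory so nothing is left behind.
    with tempfile.TemporaryDirectory() as tmp:
        for p in make_training_layout(os.path.join(tmp, "training_data")):
            print(p, os.path.isdir(p))
```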

If your computer runs out of RAM during training, it's because the
number of pairs that must be generated is massive. Set the
TRAIN_LOW_RAM_CUTOFF and DEV_LOW_RAM_CUTOFF parameters inside the convert_review module
to prevent this from happening. I have 16 GB of RAM and the current
settings work for me, but if you have less you will need to set these
parameters.
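
To get a feel for why the number of pairs is a problem, here is a rough sketch of how it grows. It assumes every unordered pair of reviews is generated and that a cutoff caps the number of reviews before pairing. Both are assumptions for illustration; the actual logic in convert_review may differ, and num_pairs and capped_pairs are hypothetical helpers.

```python
def num_pairs(n_reviews):
    """Number of unordered pairs that can be formed from n_reviews items."""
    return n_reviews * (n_reviews - 1) // 2


def capped_pairs(n_reviews, cutoff):
    """Pairs generated if at most `cutoff` reviews are used (assumed behaviour)."""
    return num_pairs(min(n_reviews, cutoff))


if __name__ == "__main__":
    for n in (1000, 10000):
        print(n, "reviews ->", num_pairs(n), "pairs")
    print("10000 reviews capped at 1000 ->", capped_pairs(10000, 1000), "pairs")
```

Because the pair count grows quadratically under this assumption, halving the number of reviews cuts the pair count to roughly a quarter.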

Training is set up with checkpoints, so after each epoch the weights will
be saved in the model_data/saved_weights directory, and model_specs will
also be saved after training is complete.



EXAMPLE:

from siamese_model import build_siamese_model, train_siamese_model


# This call returns a dictionary containing the siamese model
# and the CNN model that makes up the left and right branches of the siamese network.

models = build_siamese_model()


# This call trains the model. All training parameters can be set
# inside the siamese_network module (number of epochs, CNN parameters, etc.).
# It returns the training set, the dev set, and the training history.

trainingData, devData, hist = train_siamese_model(models)

# Once training completes you will have a trained model in
# models['siamese'], and you can use models['siamese'].predict(...)
# with dev set data to test it out. The input to predict is a list
# of the form [Lreview, Rreview], and these reviews can be found
# in devData. Both devData and trainingData can be unpacked
# like Xdev_left, ydev_left, Xdev_right, ydev_right, dev_similarity = devData
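
The unpack-and-predict step described in the comments above could look like the sketch below. The helper predict_dev_pairs is hypothetical. It only assumes what the comments state: devData unpacks into five parts, and predict takes a list of the form [Lreview, Rreview]. What the returned predictions mean (a distance or a similarity score) is not stated here, so the sketch hands them back unchanged alongside the dev similarity labels.

```python
def predict_dev_pairs(siamese, dev_data):
    """Run the siamese model on the dev pairs.

    Returns the raw predictions together with the dev similarity labels
    so the two can be compared by the caller.
    """
    Xdev_left, ydev_left, Xdev_right, ydev_right, dev_similarity = dev_data
    predictions = siamese.predict([Xdev_left, Xdev_right])
    return predictions, dev_similarity


# Usage, after training as in the example above:
# predictions, dev_similarity = predict_dev_pairs(models['siamese'], devData)
```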
