コード例 #1
0
def main():
    """Parse the command line and run the training or prediction pipeline.

    Positional arguments select the pipeline ``mode`` (``train``/``predict``),
    the ``model`` architecture and the ``dataset``. The dataset is only
    forwarded to ``train``; ``predict`` does not receive it. ``--gpu`` is the
    GPU index handed to ``set_visible_gpu`` before any config is loaded.
    """
    parser = ArgumentParser()

    parser.add_argument('mode',
                        choices=['train', 'predict'],
                        help='pipeline mode')

    parser.add_argument('model',
                        choices=['rnn', 'cnn', 'multihead'],
                        help='model to be used')

    parser.add_argument('dataset',
                        choices=['QQP', 'SNLI'],
                        help='dataset to be used')

    # argparse expands help text with %-formatting, so the placeholder needs
    # its conversion type ('s'); the old '%(default))' made `--help` raise
    # ValueError.
    parser.add_argument('--gpu',
                        default='0',
                        help='index of GPU to be used (default: %(default)s)')

    args = parser.parse_args()

    set_visible_gpu(args.gpu)

    main_config = init_config()
    model_config = init_config(args.model)

    # `mode` is restricted by `choices`, so an exact comparison suffices.
    if args.mode == 'train':
        train(main_config, model_config, args.model, args.dataset)
    else:
        predict(main_config, model_config, args.model)
コード例 #2
0
def main():
    """Parse the command line and run the training or prediction pipeline.

    ``dataset`` may be omitted on the command line (``nargs='?'``) but is
    required in ``train`` mode. When ``--experiment_name`` is not given, a
    name is derived via ``create_experiment_name`` from the model and the
    loaded configs. ``--gpu`` is the GPU index handed to ``set_visible_gpu``.
    """
    parser = ArgumentParser()

    parser.add_argument(
        'mode',
        choices=['train', 'predict'],
        help='pipeline mode',
    )

    parser.add_argument(
        'model',
        choices=['rnn', 'cnn', 'multihead'],
        help='model to be used',
    )

    parser.add_argument(
        'dataset',
        choices=['QQP', 'SNLI', 'ANLI'],
        nargs='?',
        help='dataset to be used',
    )

    parser.add_argument(
        '--experiment_name',
        required=False,
        help='the name of run experiment',
    )

    parser.add_argument(
        '--gpu',
        default='0',
        # argparse %-formats help text; the placeholder needs the 's'
        # conversion ('%(default))' made `--help` raise ValueError).
        help='index of GPU to be used (default: %(default)s)',
    )

    args = parser.parse_args()
    # Only training needs a dataset, so it is optional at parse time and
    # enforced here for that mode.
    if args.mode == 'train' and args.dataset is None:
        parser.error('Positional argument [dataset] is mandatory')
    set_visible_gpu(args.gpu)

    main_config = init_config()
    model_config = init_config(args.model)

    experiment_name = args.experiment_name
    if experiment_name is None:
        experiment_name = create_experiment_name(args.model, main_config,
                                                 model_config)

    if args.mode == 'train':
        train(main_config, model_config, args.model, experiment_name,
              args.dataset)
    else:
        predict(main_config, model_config, args.model, experiment_name)
コード例 #3
0
ファイル: train.py プロジェクト: Stacy-D/siamese
def main():
    """Parse training options, derive embedding flags and run directory, then train.

    Adds ``use_embed`` and ``tune`` to the parsed namespace from
    ``--embeddings``. It suffixes ``model_dir`` with the embedding mode,
    batch size and loss weight, and stores ``max_seq_length``/``vocab_size``
    from ``get_vocab`` before calling ``train``.
    """
    from argparse import ArgumentTypeError

    def _parse_bool(text):
        """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool."""
        lowered = text.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise ArgumentTypeError(
            'expected a boolean value (True/False), got {!r}'.format(text))

    parser = ArgumentParser()

    parser.add_argument('--data-dir',
                        default='./corpora',
                        help='Path to original quora split')

    parser.add_argument('--model-dir',
                        default='./model_dir',
                        help='Path to save the trained model')

    # `type=bool` turned every non-empty string (including 'False') into
    # True, so the flag could never be switched off; parse the text instead.
    parser.add_argument('--use-help',
                        choices=[True, False],
                        default=False,
                        type=_parse_bool,
                        help='should model use help on difficult examples')

    # argparse %-formats help text; the placeholder needs the 's' conversion
    # ('%(default))' made `--help` raise ValueError).
    parser.add_argument('--gpu',
                        default='0',
                        help='index of GPU to be used (default: %(default)s)')

    parser.add_argument('--embeddings',
                        choices=['no', 'fixed', 'tunable'],
                        default='no',
                        type=str,
                        help='embeddings')
    parser.add_argument('--batch-size',
                        choices=[4, 128, 256, 512],
                        default=128,
                        type=int,
                        help='batch size')
    # NOTE: argparse applies `type` only to string defaults, so the default
    # stays the int 1 and the default run directory ends in '_1' (not '_1.0').
    parser.add_argument('--syn-weight',
                        default=1,
                        type=float,
                        help='Weight for loss function')

    args = parser.parse_args()
    logger.info(args)
    # `choices` limits embeddings to 'no' / 'fixed' / 'tunable':
    # embeddings are used unless 'no', and tuned only for 'tunable'.
    args.use_embed = args.embeddings != 'no'
    args.tune = args.embeddings == 'tunable'

    set_visible_gpu(args.gpu)
    args.model_dir = '{}_bilstm_{}_{}_{}'.format(args.model_dir,
                                                 args.embeddings,
                                                 args.batch_size,
                                                 args.syn_weight
                                                 )

    main_config = init_config()

    args.max_seq_length, args.vocab_size = get_vocab(main_config, args, logger)
    logger.info(args)
    train(main_config, args)
コード例 #4
0
    def load_model(self, model_name):
        """Build the model named by ``model_name`` and restore its latest checkpoint.

        ``model_name`` is a directory under ``self.model_dir``. The text
        before its first '_' selects both the class in ``MODELS`` and the
        config section passed to ``init_config``.

        Raises:
            ValueError: if no checkpoint exists in the model's directory.
        """
        # The "Visualize attention weights" checkbox is shown only for
        # multihead models.
        if 'multihead' in model_name:
            self.visualize_attentions_checkbox.grid(row=2,
                                                    column=0,
                                                    sticky=W + E,
                                                    ipady=1)
        else:
            self.visualize_attentions_checkbox.grid_forget()

        # Close the previous session before discarding its graph; otherwise
        # every model switch leaks a session and the resources it holds.
        previous_session = getattr(self, 'session', None)
        if previous_session is not None:
            previous_session.close()
        tf.reset_default_graph()
        self.session = tf.Session()
        logger.info('Loading model: %s', model_name)

        model_type = model_name.split('_')[0]
        model = MODELS[model_type]
        model_config = init_config(model_type)

        self.model = model(self.max_doc_len, self.vocabulary_size,
                           self.main_config, model_config)
        saver = tf.train.Saver()
        checkpoint_dir = '{}/{}'.format(self.model_dir, model_name)
        last_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
        # latest_checkpoint returns None when the directory has no
        # checkpoint; fail with a message naming the directory.
        if last_checkpoint is None:
            raise ValueError(
                'No checkpoint found in: {}'.format(checkpoint_dir))
        saver.restore(self.session, last_checkpoint)
        logger.info('Loaded model from: %s', last_checkpoint)
コード例 #5
0
    def __init__(self, master):
        """Build the demo window, list the trained models and load the vocabulary.

        Args:
            master: the Tk widget the demo is drawn into; its ``title`` is set.
        """
        self.frame = master
        self.frame.title('Multihead Siamese Nets')

        # All text widgets share one font specification.
        font_spec = "Helvetica {}".format(GUI_FONT_SIZE)

        sample1 = StringVar(master, value=SAMPLE_SENTENCE1)
        sample2 = StringVar(master, value=SAMPLE_SENTENCE2)
        self.first_sentence_entry = Entry(
            self.frame,
            width=50,
            font=font_spec,
            textvariable=sample1)
        self.second_sentence_entry = Entry(
            self.frame,
            width=50,
            font=font_spec,
            textvariable=sample2)
        self.predictButton = Button(self.frame,
                                    text='Predict',
                                    font=font_spec,
                                    command=self.predict)
        self.clearButton = Button(self.frame,
                                  text='Clear',
                                  command=self.clear,
                                  font=font_spec)
        self.resultLabel = Label(self.frame,
                                 text='Result',
                                 font=font_spec)
        self.first_sentence_label = Label(
            self.frame,
            text='Sentence 1',
            font=font_spec)
        self.second_sentence_label = Label(
            self.frame,
            text='Sentence 2',
            font=font_spec)

        self.main_config = init_config()
        self.model_dir = str(self.main_config['DATA']['model_dir'])

        # Models are loaded from '<model_dir>/<name>', so only immediate
        # subdirectories are valid choices. os.walk also yielded model_dir
        # itself and nested directories as unloadable menu entries.
        model_dirs = sorted(
            entry for entry in os.listdir(self.model_dir)
            if os.path.isdir(os.path.join(self.model_dir, entry)))

        self.visualize_attentions = IntVar()
        self.visualize_attentions_checkbox = Checkbutton(
            master,
            text="Visualize attention weights",
            font="Helvetica {}".format(int(GUI_FONT_SIZE / 2)),
            variable=self.visualize_attentions,
            onvalue=1,
            offvalue=0)

        variable = StringVar(master)
        variable.set('Choose a model...')
        if model_dirs:
            self.model_type = OptionMenu(master,
                                         variable,
                                         *model_dirs,
                                         command=self.load_model)
        else:
            # tkinter's OptionMenu needs at least one value; with no trained
            # models, show a disabled placeholder instead of crashing.
            self.model_type = OptionMenu(master, variable, 'Choose a model...')
            self.model_type.configure(state='disabled')
        self.model_type.configure(font=('Helvetica', GUI_FONT_SIZE))

        self.first_sentence_entry.grid(row=0, column=1, columnspan=4)
        self.first_sentence_label.grid(row=0, column=0, sticky=E)
        self.second_sentence_entry.grid(row=1, column=1, columnspan=4)
        self.second_sentence_label.grid(row=1, column=0, sticky=E)
        self.model_type.grid(row=2, column=1, sticky=W + E, ipady=1)
        self.predictButton.grid(row=2, column=2, sticky=W + E, ipady=1)
        self.clearButton.grid(row=2, column=3, sticky=W + E, ipady=1)
        self.resultLabel.grid(row=2, column=4, sticky=W + E, ipady=1)

        self.vectorizer = DatasetVectorizer(self.model_dir)

        self.max_doc_len = self.vectorizer.max_sentence_len
        self.vocabulary_size = self.vectorizer.vocabulary_size

        self.session = tf.Session()
        self.model = None