Code example #1
def train(train_path, test=None, ckpt='./checkpoint/'):
    if not os.path.exists(ckpt):
        os.makedirs(ckpt)

    data = Dataloader(FLAGS['embedding_file'], FLAGS['seq_len'],
                      FLAGS['embedding_dim'])
    if test:
        testo, testc, _ = data.dataload(test)

    config = tf.ConfigProto(device_count={'GPU': 0})  # run on CPU only; drop device_count to use the GPU
    with tf.Session(config=config) as sess:
        model = Lstm_ranking(data.vocab.embedding, FLAGS)
        saver = tf.train.Saver(max_to_keep=5)
        sess.run(tf.global_variables_initializer())

        best_acc = 0.7
        acc_record = []
        for epoch in range(1, FLAGS['num_epochs'] + 1):  # run all num_epochs epochs
            # data_q, data_c = data_loader(train_path)
            # rewrite_train(data_q, data_c, './logs/train_data', 5)
            generate_train(train_path, './logs/train_data', 3)
            traino, trainc, trainn = data.dataload('./logs/train_data')
            # train
            stime = time.time()
            for batch in data.batch_iter(list(zip(traino, trainc, trainn)),
                                         FLAGS['batch_size']):
                batch_o, batch_c, batch_n = zip(*batch)
                step, loss = model.train_step(sess, batch_o, batch_c, batch_n)
                if step % 100 == 0:
                    print(
                        'train, epoch: %d, step: %d, loss: %.4f, spend-time: %.4fs'
                        % (epoch, step, loss, time.time() - stime))
                    stime = time.time()

            if test:
                acc, acc_3, index, score = model.dev_step(sess, testo, testc)
                acc_record.append(acc)
                print(
                    "Evaluation   epoch: %d, top1-acc: %.4f, top3-acc: %.4f" %
                    (epoch, acc, acc_3))
                if acc > best_acc - 0.003:
                    saver.save(
                        sess,
                        os.path.join(
                            ckpt, 'lstm-model-' + str(epoch) + '-' +
                            str(acc * 100)[:5] + '.ckpt'))
                    write_res(test, index, score,
                              './logs/result-lstm/epoch-' + str(epoch))

                    best_acc = acc
            else:
                saver.save(
                    sess,
                    os.path.join(ckpt, 'lstm-model-' + str(epoch) + '.ckpt'))

        return acc_record
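
The script above reads its hyper-parameters from a FLAGS dictionary defined elsewhere in the project. A minimal sketch of what it would need to contain (the keys come from the code above; every value is a placeholder):

# Hypothetical FLAGS for the script above; all values are placeholders.
FLAGS = {
    'embedding_file': './data/embeddings.txt',  # pretrained word vectors
    'seq_len': 50,             # maximum tokens per sequence
    'embedding_dim': 300,      # dimensionality of the word vectors
    'num_epochs': 20,
    'batch_size': 64,
}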
Code example #2
def test(config):
    import sent2vec
    assert config.sent2vec.model is not None, "Please add sent2vec_model config value."
    sent2vec_model = sent2vec.Sent2vecModel()
    sent2vec_model.load_model(config.sent2vec.model)

    output_fn_test = OutputFnTest(sent2vec_model, config)

    test_set = Dataloader(config, 'data/test_stories.csv', testing_data=True)
    test_set.load_dataset('data/test.bin')
    test_set.load_vocab('./data/default.voc', config.vocab_size)
    test_set.set_output_fn(output_fn_test)

    generator_testing = test_set.get_batch(config.batch_size,
                                           config.n_epochs,
                                           random=True)

    keras_model = keras.models.load_model(
        './builds/leonhard/2018-06-08 12:04:03-entailmentv6_checkpoint_epoch-85.hdf5'
    )

    verbose = 0 if not config.debug else 1

    # test_batch = next(generator_testing)
    print(keras_model.metrics_names)
    loss = keras_model.evaluate_generator(generator_testing,
                                          steps=len(test_set) //
                                          config.batch_size,
                                          verbose=verbose)
    print(loss)
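
evaluate_generator expects an integer number of steps, hence the floor division above. If the test-set size is not a multiple of the batch size, rounding up instead covers the final partial batch; a sketch:

import math

steps = math.ceil(len(test_set) / config.batch_size)  # include the last partial batch
loss = keras_model.evaluate_generator(generator_testing, steps=steps, verbose=verbose)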
Code example #3
    def test(self):
        # Initialize tensorflow session
        sess = tf.Session()
        K.set_session(sess)  # Set to keras backend

        if self.config.debug:
            print('Importing Elmo module...')
        if self.config.hub.is_set("cache_dir"):
            os.environ['TFHUB_CACHE_DIR'] = self.config.hub.cache_dir

        elmo_model = hub.Module("https://tfhub.dev/google/elmo/1",
                                trainable=True)
        if self.config.debug:
            print('Imported.')

        # True if pretrained encoder/decoder models were provided in the config
        self.use_pretrained_models = self.config.alignment.is_set(
            'decoder_target_model') and self.config.alignment.is_set(
                'decoder_src_model') and self.config.alignment.is_set(
                    'encoder_target_model') and self.config.alignment.is_set(
                        'encoder_src_model')

        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())

        self.graph = tf.get_default_graph()

        elmo_emb_fn = ElmoEmbedding(elmo_model)

        elmo_embeddings = keras.layers.Lambda(elmo_emb_fn,
                                              output_shape=(1024, ))
        sentence = keras.layers.Input(shape=(1, ), dtype="string")
        sentence_emb = elmo_embeddings(sentence)

        self.elmo_model = keras.models.Model(inputs=sentence,
                                             outputs=sentence_emb)

        test_set = Dataloader(self.config,
                              'data/test_stories.csv',
                              testing_data=True)
        test_set.load_dataset('data/test.bin')
        test_set.load_vocab('./data/default.voc', self.config.vocab_size)
        test_set.set_output_fn(self.output_fn_test)

        generator_test = test_set.get_batch(self.config.batch_size,
                                            self.config.n_epochs)

        model = keras.models.load_model(self.config.alignment.final_model)

        print(model.metrics_names)
        acc_targets = []
        acc_srcs = []
        for inputs, labels in generator_test:
            results = model.evaluate(inputs, labels, batch_size=len(inputs))
            acc_target, acc_src = results[-4], results[-5]
            acc_targets.append(acc_target)
            acc_srcs.append(acc_src)
            print(np.mean(acc_targets), np.mean(acc_srcs))
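
ElmoEmbedding wraps the TF-Hub module in a callable that keras.layers.Lambda can invoke. The class itself is not shown in these examples; a minimal sketch of the usual pattern for the ELMo 1 module, whose "default" signature returns a 1024-dimensional sentence embedding (an assumption, not the project's actual class):

import tensorflow as tf

class ElmoEmbedding:
    """Callable wrapper so a hub.Module can be used inside keras.layers.Lambda."""

    def __init__(self, elmo_model):
        self.elmo_model = elmo_model

    def __call__(self, x):
        # x is a (batch, 1) string tensor; squeeze it to (batch,) and embed.
        return self.elmo_model(tf.squeeze(tf.cast(x, tf.string), axis=[1]),
                               signature="default", as_dict=True)["default"]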
Code example #4
File: train.py Project: luofeng1994/GeolifeData
def main():
    args = dict()
    args['corpus_path'] = '../vector_traing_corpus/clusterTrajectory_100m_5minutes_50eps_5minPts_uniq_2windows.txt'
    args['train_data_path'] = '../vector_traing_corpus/clusterTrajectory_100m_5minutes_50eps_5minPts_uniq_2windows.txt'
    args['dbscanner_path'] = '../utils/dbscaner_100m_5minutes_50eps_5minPts.pkl'
    args['batch_size'] = 1
    args['neg_sample_size'] = 5
    args['alpha'] = 0.75  # smooth out unigram frequencies
    args['table_size'] = 1000  # table size from which to sample neg samples
    args['min_frequency'] = 1  # threshold for vocab frequency
    args['lr'] = 0.05
    args['min_lr'] = 0.005
    args['embed_size'] = 128
    args['sampling'] = False
    args['epoches'] = 70
    args['save_every_n'] = 200
    args['save_dir'] = './save_windowns{}'.format(args['neg_sample_size'])
    dataloader = Dataloader(args)
    args['vocab_size'] = dataloader.vocab_size
    pickle.dump(dataloader, open('./variable/dataloader.pkl', 'wb'))  # pickle needs binary mode
    pickle.dump(args, open('./variable/args.pkl', 'wb'))

    model = Model(args)
    saver = tf.train.Saver(max_to_keep=10)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        summary_writer = tf.summary.FileWriter('./log', sess.graph)
        count = 0
        for e in range(args['epoches']):
            dataloader.reset_pointer()
            for i in range(dataloader.batch_num):
                count += 1
                start = time.time()
                x, y = dataloader.next_batch()
                # labels = dataloader.labels
                args['lr'] = args['lr'] + dataloader.decay
                feed = {model.x: x,
                        model.y: y,
                        model.lr: args['lr']}
                summary, train, loss = sess.run([model.merged, model.train, model.loss], feed_dict=feed)
                summary_writer.add_summary(summary, count)
                end = time.time()
                if count % 100 == 0:
                    print('round_num: {}/{}... '.format(e + 1, args['epoches']),
                          'Training steps: {}... '.format(count),
                          'Training error: {:.4f}... '.format(loss),
                          'Learning rate: {:.4f}... '.format(args['lr']),
                          '{:.4f} sec/batch'.format((end - start)))
                if (count % args['save_every_n'] == 0):
                    saver.save(sess, "{path}/i{counter}.ckpt".format(path = args['save_dir'], counter=count))
        saver.save(sess, "{path}/i{counter}.ckpt".format(path=args['save_dir'], counter=count))
        summary_writer.close()
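
Because the script pickles both the Dataloader and the args dict, a later script can restore them without rebuilding the vocabulary; a sketch, assuming the same relative paths:

import pickle

with open('./variable/args.pkl', 'rb') as f:       # binary mode, matching the dump above
    args = pickle.load(f)
with open('./variable/dataloader.pkl', 'rb') as f:
    dataloader = pickle.load(f)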
Code example #5
def main(config):
    import sent2vec
    assert config.sent2vec.model is not None, "Please add sent2vec_model config value."
    sent2vec_model = sent2vec.Sent2vecModel()
    sent2vec_model.load_model(config.sent2vec.model)

    preprocess_fn = Preprocess(sent2vec_model)

    output_fn_test = OutputFnTest(sent2vec_model, config)

    train_set = SNLIDataloaderPairs('data/snli_1.0/snli_1.0_train.jsonl')
    train_set.set_preprocess_fn(preprocess_fn)
    train_set.set_output_fn(output_fn)

    test_set = Dataloader(config, 'data/test_stories.csv', testing_data=True)
    test_set.load_dataset('data/test.bin')
    test_set.load_vocab('./data/default.voc', config.vocab_size)
    test_set.set_output_fn(output_fn_test)
    # dev_set = SNLIDataloader('data/snli_1.0/snli_1.0_dev.jsonl')
    # dev_set.set_preprocess_fn(preprocess_fn)
    # dev_set.set_output_fn(output_fn)
    # test_set = SNLIDataloader('data/snli_1.0/snli_1.0_test.jsonl')

    generator_training = train_set.get_batch(config.batch_size,
                                             config.n_epochs)
    generator_dev = test_set.get_batch(config.batch_size, config.n_epochs)

    keras_model = model(config)

    verbose = 0 if not config.debug else 1
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Callbacks
    tensorboard = keras.callbacks.TensorBoard(log_dir='./logs/' + timestamp +
                                              '-entailmentv5/',
                                              histogram_freq=0,
                                              batch_size=config.batch_size,
                                              write_graph=False,
                                              write_grads=True)

    model_path = os.path.abspath(
        os.path.join(os.curdir, './builds/' + timestamp))
    model_path += '-entailmentv5_checkpoint_epoch-{epoch:02d}.hdf5'

    saver = keras.callbacks.ModelCheckpoint(model_path,
                                            monitor='val_loss',
                                            verbose=verbose,
                                            save_best_only=True)

    keras_model.fit_generator(generator_training,
                              steps_per_epoch=300,
                              epochs=config.n_epochs,
                              verbose=verbose,
                              validation_data=generator_dev,
                              validation_steps=len(test_set) //
                              config.batch_size,
                              callbacks=[tensorboard, saver])
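
Preprocess is built around the sent2vec model but its body is not shown. A plausible minimal version, assuming it simply embeds a sentence with the embed_sentence call from the epfml/sent2vec Python bindings (the class shape here is a guess):

class Preprocess:
    """Hypothetical sketch: map a sentence string to its sent2vec embedding."""

    def __init__(self, sent2vec_model):
        self.model = sent2vec_model

    def __call__(self, sentence):
        # embed_sentence returns a (1, dim) array; return the single vector.
        return self.model.embed_sentence(sentence)[0]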
Code example #6
def percentage_data():
    percentage = request.args.get('percentage', '')
    print("percent:", percentage)
    dataloader = Dataloader.getInstance()
    #dataloader.build_dataframes(False)
    json_data = dataloader.get_percentage_data(int(percentage))
    json_data = dataloader.append_alg_params(json_data)
    return json_data
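
This Flask handler (and the related handlers in examples #7 and #12) fetches a shared Dataloader through getInstance(), a classic singleton, so every request reuses the same parsed dataframes instead of reloading them. The accessor is not shown; a minimal sketch of the pattern:

class Dataloader:
    _instance = None

    @classmethod
    def getInstance(cls):
        # Create the shared loader lazily on first use.
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance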
Code example #7
def graph_init():
    interested_dataset = './data/crowds/students003.txt'
    dataloader = Dataloader.getInstance()
    dataloader.build_dataframes(True, 30)
    json_data = dataloader.retrieve_df(interested_dataset)
    dataloader.retrieved_dataset = interested_dataset
    json_data = dataloader.append_alg_params(json_data)
    return render_template("index.html", json_data=json_data)
Code example #8
def test(model_path, batch_size=1):
    print('Load Model:')
    netG = load_model(model_path)
    dataloader = Dataloader()
    for batch_i, (imgs_X, imgs_Y) in enumerate(
            dataloader.load_batch(batch_size, for_testing=True)):
        preds = netG.predict(imgs_X)  # translate the domain-X batch
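
Both CycleGAN examples (#8 here and #20 below) depend on Dataloader.load_batch, a generator over paired image batches, whose implementation is not included. A rough sketch, assuming the two domains live in separate directories and images are normalized to [-1, 1]:

import glob
import numpy as np
from PIL import Image  # assumption: Pillow is used for loading

def load_batch(self, batch_size=1, for_testing=False):
    split = 'test' if for_testing else 'train'
    paths_X = glob.glob('./datasets/%sX/*' % split)   # hypothetical layout
    paths_Y = glob.glob('./datasets/%sY/*' % split)
    n_batches = min(len(paths_X), len(paths_Y)) // batch_size
    for i in range(n_batches):
        batch = slice(i * batch_size, (i + 1) * batch_size)
        imgs_X = [np.asarray(Image.open(p).resize((256, 256))) / 127.5 - 1.0
                  for p in paths_X[batch]]
        imgs_Y = [np.asarray(Image.open(p).resize((256, 256))) / 127.5 - 1.0
                  for p in paths_Y[batch]]
        yield np.array(imgs_X), np.array(imgs_Y)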
Code example #9
def main(config):
    train_set = SNLIDataloader('data/snli_1.0/snli_1.0_train.jsonl')
    train_set.set_preprocess_fn(preprocess_fn)
    train_set.set_output_fn(output_fn)
    dev_set = Dataloader(config, 'data/test_stories.csv', testing_data=True)
    # dev_set.set_preprocess_fn(preprocess_fn)
    dev_set.load_dataset('data/test.bin')
    dev_set.load_vocab('./data/default.voc', config.vocab_size)
    dev_set.set_output_fn(output_fn_test)
    # test_set = SNLIDataloader('data/snli_1.0/snli_1.0_test.jsonl')

    generator_training = train_set.get_batch(config.batch_size,
                                             config.n_epochs)
    generator_dev = dev_set.get_batch(config.batch_size, config.n_epochs)

    # Initialize tensorflow session
    sess = tf.Session()
    K.set_session(sess)  # Set to keras backend

    keras_model = model(sess, config)
    keras_model.summary()  # summary() prints the model itself; print() would show 'None'

    verbose = 0 if not config.debug else 1
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Callbacks
    tensorboard = keras.callbacks.TensorBoard(log_dir='./logs/' + timestamp +
                                              '-entailmentv1/',
                                              histogram_freq=0,
                                              batch_size=config.batch_size,
                                              write_graph=False,
                                              write_grads=True)

    model_path = os.path.abspath(
        os.path.join(os.curdir, './builds/' + timestamp))
    model_path += '-entailmentv1_checkpoint_epoch-{epoch:02d}.hdf5'

    saver = keras.callbacks.ModelCheckpoint(model_path,
                                            monitor='val_loss',
                                            verbose=verbose,
                                            save_best_only=True)

    keras_model.fit_generator(generator_training,
                              steps_per_epoch=300,
                              epochs=config.n_epochs,
                              verbose=verbose,
                              validation_data=generator_dev,
                              validation_steps=len(dev_set) //
                              config.batch_size,
                              callbacks=[tensorboard, saver])
Code example #10
def main(args):
    es = Elasticsearch(f'es:{args.p}')
    create_index(es, args.index)
    image_dir = os.path.join(os.path.dirname(__file__), 'images')
    loader = Dataloader(image_dir, 32)
    fe = FeatureExtractor()

    for i in tqdm(range(len(loader))):
        path, image = loader[i]  # idiomatic indexing instead of calling __getitem__
        vector = fe.predict(image)
        docs = [{
            '_index': args.index,
            '_source': {
                'path': str(p),
                'vector': list(v)
            }
        } for p, v in zip(path, vector)]
        helpers.bulk(es, docs)

    print("Preparing complete")
Code example #11
File: test.py Project: luofeng1994/GeolifeData
def main():
    args = pickle.load(open('./utils/args.pkl', 'rb'))  # binary mode for pickle
    args['is_training'] = False
    # args['save_dir'] = './save'
    # args['train_data_path'] = '../predict_training_corpus/clusterTrajectory_100m_5minutes_50eps_5minPts_uniq_5numseqs_train.txt'
    args['test_data_path'] = '../predict_training_corpus/clusterTrajectory_100m_5minutes_50eps_5minPts_uniq_3numseqs_train.txt'
    args['num_steps'] = 3

    dataloader = Dataloader(args)
    model = Model(args)

    count = 0
    count_correct = 0
    with tf.Session() as sess:
        saver = tf.train.Saver(tf.global_variables())  # tf.all_variables() is deprecated
        ckpt = tf.train.get_checkpoint_state(args['save_dir'])
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)  # restore once, not per row
            for row in dataloader.data:
                new_state = sess.run(model.initial_state)
                datas = row['data']
                labels = int(row['label'])
                x = dataloader.create_input(datas)
                feed = {model.input: x, model.initial_state: new_state}
                prediction = sess.run(model.prediction, feed_dict=feed)
                # argmax on the returned ndarray instead of building a new
                # tf.argmax op on every iteration
                prediction = prediction.argmax(axis=1)
                prediction = prediction[0]

                count += 1
                if labels == -1:
                    labels = 0
                label_pred = prediction
                print('{}:{}->{}'.format(datas, labels, label_pred))
                if labels == label_pred:
                    count_correct += 1
    print('total count:{}, correct count:{}, correct_rate:{}'.format(
        count, count_correct, count_correct / count))
Code example #12
def run_dbscan():
    pts = request.args.get('pts', '')
    spatial = request.args.get('spatial', '')
    temporal = request.args.get('temporal', '')
    velocity = request.args.get('velocity', '')
    percent = request.args.get('percent', '')
    print("received: ", pts, spatial, temporal, velocity, percent)
    dataloader = Dataloader.getInstance()
    json_data = dataloader.run_dbscan_with_params(float(pts), float(spatial),
                                                  float(temporal),
                                                  float(velocity),
                                                  float(percent))
    json_data = dataloader.append_alg_params(json_data)
    return json_data
Code example #13
    def eval(self):
        # Initialize tensorflow session
        sess = tf.Session()
        K.set_session(sess)  # Set to keras backend

        if self.config.debug:
            print('Importing Elmo module...')
        if self.config.hub.is_set("cache_dir"):
            os.environ['TFHUB_CACHE_DIR'] = self.config.hub.cache_dir

        elmo_model = hub.Module("https://tfhub.dev/google/elmo/1",
                                trainable=True)
        if self.config.debug:
            print('Imported.')

        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())

        graph = tf.get_default_graph()

        elmo_emb_fn = ElmoEmbedding(elmo_model)

        elmo_model_emb = get_elmo_embedding(elmo_emb_fn)

        type_translation_model = keras.models.load_model(
            self.config.type_translation_model)

        output_fn = OutputFN(elmo_model_emb, type_translation_model, graph)

        test_set = Dataloader(self.config,
                              'data/test_stories.csv',
                              testing_data=True)
        # test_set.set_preprocess_fn(preprocess_fn)
        test_set.load_dataset('data/test.bin')
        test_set.load_vocab('./data/default.voc', self.config.vocab_size)
        test_set.set_output_fn(output_fn)

        generator_test = test_set.get_batch(self.config.batch_size, 1)
        accuracy = []
        for batch in generator_test:
            accuracy.append(batch)
            print(np.mean(accuracy))
Code example #14
    def train(self):
        train_set = Dataloader(self.config, 'data/train_stories.csv')
        test_set = Dataloader(self.config,
                              'data/test_stories.csv',
                              testing_data=True)
        train_set.load_dataset('data/train.bin')
        train_set.load_vocab('./data/default.voc', self.config.vocab_size)
        test_set.load_dataset('data/test.bin')
        test_set.load_vocab('./data/default.voc', self.config.vocab_size)
        test_set.set_output_fn(output_fn_test)
        train_set.set_output_fn(output_fn)

        generator_training = train_set.get_batch(self.config.batch_size,
                                                 self.config.n_epochs)
        generator_dev = test_set.get_batch(self.config.batch_size,
                                           self.config.n_epochs)

        # Initialize tensorflow session
        sess = tf.Session()
        K.set_session(sess)  # Set to keras backend

        keras_model = self.build_graph(sess)

        verbose = 0 if not self.config.debug else 1
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        # Callbacks
        tensorboard = keras.callbacks.TensorBoard(
            log_dir='./logs/' + timestamp + '-reorder-full-rd-elmo/',
            histogram_freq=0,
            batch_size=self.config.batch_size,
            write_graph=False,
            write_grads=True)

        model_path = os.path.abspath(
            os.path.join(os.curdir, './builds/' + timestamp))
        model_path += '-reorder-full-rd-elmo_checkpoint_epoch-{epoch:02d}.hdf5'

        saver = keras.callbacks.ModelCheckpoint(model_path,
                                                monitor='val_loss',
                                                verbose=verbose,
                                                save_best_only=True)

        keras_model.fit_generator(generator_training,
                                  steps_per_epoch=300,
                                  epochs=self.config.n_epochs,
                                  verbose=verbose,
                                  validation_data=generator_dev,
                                  validation_steps=len(test_set) //
                                  self.config.batch_size,
                                  callbacks=[tensorboard, saver])
Code example #15
def main(config):
    # Initialize tensorflow session
    sess = tf.Session()
    K.set_session(sess)  # Set to keras backend

    if config.debug:
        print('Importing Elmo module...')
    if config.hub.is_set("cache_dir"):
        os.environ['TFHUB_CACHE_DIR'] = config.hub.cache_dir

    elmo_model = hub.Module("https://tfhub.dev/google/elmo/1", trainable=True)
    if config.debug:
        print('Imported.')

    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())

    graph = tf.get_default_graph()

    elmo_emb_fn = ElmoEmbedding(elmo_model)

    elmo_model_emb = get_elmo_embedding(elmo_emb_fn)

    type_translation_model = keras.models.load_model(
        config.type_translation_model, {'elmo_embeddings': elmo_emb_fn})

    output_fn = OutputFN(elmo_model_emb, type_translation_model, graph)
    output_fn_test = OutputFNTest(elmo_model_emb, type_translation_model,
                                  graph)
    train_set = Dataloader(config, 'data/train_stories.csv')
    # test_set.set_preprocess_fn(preprocess_fn)
    train_set.load_dataset('data/train.bin')
    train_set.load_vocab('./data/default.voc', config.vocab_size)
    train_set.set_output_fn(output_fn)

    test_set = Dataloader(config, 'data/test_stories.csv', testing_data=True)
    # test_set.set_preprocess_fn(preprocess_fn)
    test_set.load_dataset('data/test.bin')
    test_set.load_vocab('./data/default.voc', config.vocab_size)
    test_set.set_output_fn(output_fn_test)

    generator_training = train_set.get_batch(config.batch_size,
                                             config.n_epochs)
    generator_test = test_set.get_batch(config.batch_size, config.n_epochs)

    # print(next(generator_training))

    keras_model = model()

    verbose = 0 if not config.debug else 1
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Callbacks
    tensorboard = keras.callbacks.TensorBoard(log_dir='./logs/' + timestamp +
                                              '-entailmentv4/',
                                              histogram_freq=0,
                                              batch_size=config.batch_size,
                                              write_graph=False,
                                              write_grads=True)

    model_path = os.path.abspath(
        os.path.join(os.curdir, './builds/' + timestamp))
    model_path += '-entailmentv4_checkpoint_epoch-{epoch:02d}.hdf5'

    saver = keras.callbacks.ModelCheckpoint(model_path,
                                            monitor='val_loss',
                                            verbose=verbose,
                                            save_best_only=True)

    keras_model.fit_generator(generator_training,
                              steps_per_epoch=5,
                              epochs=config.n_epochs,
                              verbose=verbose,
                              validation_data=generator_test,
                              validation_steps=5,
                              callbacks=[tensorboard, saver])
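
Note the second argument to load_model above: an HDF5 file whose graph contains a Lambda layer can only be deserialized if the wrapped Python callable is supplied again through custom_objects. The same applies when reloading the checkpoints this script writes; a sketch with an illustrative path:

# The checkpoint filename is illustrative, not an actual artifact of this run.
restored = keras.models.load_model(
    './builds/2018-01-01 00:00:00-entailmentv4_checkpoint_epoch-01.hdf5',
    custom_objects={'elmo_embeddings': elmo_emb_fn})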
Code example #16
    def test(self):
        testing_set = Dataloader(self.config, testing_data=True)
        testing_set.load_dataset('data/test.bin')

        testing_set.load_vocab('data/default.voc', self.config.vocab_size)
        test(self.config, testing_set)
Code example #17
    def train(self):
        # Initialize tensorflow session
        sess = tf.Session()
        K.set_session(sess)  # Set to keras backend

        if self.config.debug:
            print('Importing Elmo module...')
        if self.config.hub.is_set("cache_dir"):
            os.environ['TFHUB_CACHE_DIR'] = self.config.hub.cache_dir

        elmo_model = hub.Module("https://tfhub.dev/google/elmo/1",
                                trainable=True)
        if self.config.debug:
            print('Imported.')

        # True if pretrained encoder/decoder models were provided in the config
        self.use_pretrained_models = self.config.alignment.is_set(
            'decoder_target_model') and self.config.alignment.is_set(
                'decoder_src_model') and self.config.alignment.is_set(
                    'encoder_target_model') and self.config.alignment.is_set(
                        'encoder_src_model')

        sess.run(tf.global_variables_initializer())
        sess.run(tf.tables_initializer())

        self.graph = tf.get_default_graph()

        elmo_emb_fn = ElmoEmbedding(elmo_model)

        elmo_embeddings = keras.layers.Lambda(elmo_emb_fn,
                                              output_shape=(1024, ))
        sentence = keras.layers.Input(shape=(1, ), dtype="string")
        sentence_emb = elmo_embeddings(sentence)

        self.elmo_model = keras.models.Model(inputs=sentence,
                                             outputs=sentence_emb)

        train_set = SNLIDataloaderPairs('data/snli_1.0/snli_1.0_train.jsonl')
        train_set.set_preprocess_fn(preprocess_fn)
        train_set.load_vocab('./data/snli_vocab.dat', self.config.vocab_size)
        train_set.set_output_fn(self.output_fn)

        test_set = Dataloader(self.config,
                              'data/test_stories.csv',
                              testing_data=True)
        test_set.load_dataset('data/test.bin')
        test_set.load_vocab('./data/default.voc', self.config.vocab_size)
        test_set.set_output_fn(self.output_fn_test)

        generator_training = train_set.get_batch(self.config.batch_size,
                                                 self.config.n_epochs)
        generator_dev = test_set.get_batch(self.config.batch_size,
                                           self.config.n_epochs)

        self.define_models()

        model = self.build_graph()
        frozen_model = self.build_frozen_graph()

        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        writer = tf.summary.FileWriter('./logs/' + timestamp + '-alignment/',
                                       self.graph)

        model_path = os.path.abspath(
            os.path.join(os.curdir, './builds/' + timestamp))
        model_path += '-alignment-model_checkpoint_step-'

        last_created_file = None

        self.use_frozen = False
        min_source_loss = None

        for k, (inputs, labels) in enumerate(generator_training):
            # We train the frozen model and the unfrozen model jointly
            if self.use_frozen:
                # Generator training
                metrics = frozen_model.train_on_batch(inputs, labels)
                if not k % self.config.print_train_every:
                    print_on_tensorboard(writer, frozen_model.metrics_names,
                                         metrics, k, 'train_f')
            else:
                metrics = model.train_on_batch(inputs, labels)
                if not k % self.config.print_train_every:
                    print_on_tensorboard(writer, model.metrics_names, metrics,
                                         k, 'train_uf')

            self.use_frozen = not self.use_frozen

            if k > 0 and not k % self.config.test_and_save_every:
                test_metrics = []
                for j, (inputs_val, labels_val) in enumerate(generator_dev):
                    test_metrics.append(
                        frozen_model.test_on_batch(inputs_val, labels_val))
                test_metrics = np.mean(test_metrics, axis=0)
                # Save value to tensorboard
                print_on_tensorboard(writer, frozen_model.metrics_names,
                                     test_metrics, k, 'test')
                test_metrics_dict = get_dict_from_lists(
                    frozen_model.metrics_names, test_metrics)
                # Save the model if the source-discriminator loss improved;
                # we only want to keep the generator model
                if min_source_loss is None or test_metrics_dict[
                        'disrc_src_loss'] < min_source_loss:
                    frozen_model.save(model_path + str(k) + ".hdf5")
                    if last_created_file is not None:
                        os.remove(last_created_file)  # Only keep the best one
                    last_created_file = model_path + str(k) + ".hdf5"
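
print_on_tensorboard and get_dict_from_lists are small helpers that do not appear in these snippets. Plausible minimal versions, assuming the TF 1.x summary API used elsewhere in this example:

import tensorflow as tf

def print_on_tensorboard(writer, metrics_names, metrics, step, prefix=''):
    """Write one scalar summary per metric under an optional tag prefix."""
    summary = tf.Summary()
    for name, value in zip(metrics_names, metrics):
        summary.value.add(tag=prefix + '_' + name, simple_value=float(value))
    writer.add_summary(summary, step)

def get_dict_from_lists(keys, values):
    """Pair metric names with their values for lookup by name."""
    return dict(zip(keys, values))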
Code example #18
File: train.py Project: luofeng1994/GeolifeData
def main():
    args = dict()
    args[
        'train_data_path'] = '../predict_training_corpus/clusterTrajectory_100m_5minutes_50eps_5minPts_uniq_3numseqs_train.txt'
    args[
        'test_data_path'] = '../predict_training_corpus/clusterTrajectory_100m_5minutes_50eps_5minPts_uniq_3numseqs_test.txt'
    args[
        'dbscanner_path'] = '../utils/dbscaner_100m_5minutes_50eps_5minPts.pkl'
    args['save_dir'] = './save_3numsteps'
    args['embedding_path'] = '../word2vec/variable/embedding.pkl'
    args['batch_size'] = 5
    args['lstm_size'] = 128
    args['lstm_layer'] = 1
    args['weight_decay'] = 0.00001
    args['is_training'] = True
    args['keep_prob'] = 0.5
    args['lr'] = 0.001
    args['epochs'] = 70
    args['save_every_n'] = 200

    dataloader = Dataloader(args)
    args['num_steps'] = dataloader.num_steps
    args['feature_dim'] = dataloader.feature_dim
    args['classes'] = dataloader.classes

    pickle.dump(args, open('./utils/args.pkl', 'wb'))

    model = Model(args)
    saver = tf.train.Saver(max_to_keep=10)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        count = 0
        for e in range(args['epochs']):
            dataloader.reset()
            # new_state = sess.run(model.initial_state)
            for i in range(dataloader.batch_num):
                count += 1
                x, y = dataloader.next_batch()
                start = time.time()
                feed = {model.input: x, model.target: y}
                batch_loss, new_state, _ = sess.run(
                    [model.loss, model.final_state, model.optimizer],
                    feed_dict=feed)
                end = time.time()
                if count % 100 == 0:
                    print('round_num: {}/{}... '.format(e + 1, args['epochs']),
                          'Training steps: {}... '.format(count),
                          'Training error: {:.4f}... '.format(batch_loss),
                          '{:.4f} sec/batch'.format((end - start)))

                if (count % args['save_every_n'] == 0):
                    # aaa = 1
                    saver.save(
                        sess, "{path}/i{counter}_l{lstm_size}.ckpt".format(
                            path=args['save_dir'],
                            counter=count,
                            lstm_size=args['lstm_size']))
        saver.save(
            sess, "{path}/i{counter}_l{lstm_size}.ckpt".format(
                path=args['save_dir'],
                counter=count,
                lstm_size=args['lstm_size']))
Code example #19
    def train(self):
        output_fn = OutputFN(self.config.GLOVE_PATH, self.config.model_path)
        train_set = Dataloader(self.config, 'data/train_stories.csv')
        test_set = Dataloader(self.config, 'data/test_stories.csv', testing_data=True)
        train_set.set_special_tokens(["<unk>"])
        test_set.set_special_tokens(["<unk>"])
        train_set.load_dataset('data/train.bin')
        train_set.load_vocab('./data/default.voc', self.config.vocab_size)
        test_set.load_dataset('data/test.bin')
        test_set.load_vocab('./data/default.voc', self.config.vocab_size)
        test_set.set_output_fn(output_fn.output_fn_test)
        train_set.set_output_fn(output_fn)
        generator_training = train_set.get_batch(self.config.batch_size, 1)
        generator_dev = test_set.get_batch(self.config.batch_size, 1)
        epoch = 0
        max_acc = 0
        start = time.time()
        # The summary writer is used below but was never created; the log
        # directory name here is an assumption.
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        writer = tf.summary.FileWriter('./logs/' + timestamp + '-alignment-torch/')
        encoder_src = Encoder_src().cuda()
        decoder_src = Decoder_src().cuda()
        encoder_tgt = Encoder_tgt().cuda()
        decoder_tgt = Decoder_tgt().cuda()
        discriminator = Discriminator().cuda()
        encoder_optimizer_source = optim.Adam(encoder_src.parameters(),
                                              lr=self.config.learning_rate)
        decoder_optimizer_source = optim.Adam(decoder_src.parameters(),
                                              lr=self.config.learning_rate)
        encoder_optimizer_target = optim.Adam(encoder_tgt.parameters(),
                                              lr=self.config.learning_rate)
        decoder_optimizer_target = optim.Adam(decoder_tgt.parameters(),
                                              lr=self.config.learning_rate)
        discriminator_optimizer = optim.RMSprop(
            discriminator.parameters(),
            lr=self.config.learning_rate_discriminator)
        criterion_adver = nn.BCELoss()
        target_adver_src = Variable(torch.zeros(self.config.batch_size)).cuda()
        target_adver_tgt = Variable(torch.ones(self.config.batch_size)).cuda()
        plot_loss_total = 0
        plot_loss_total_adv = 0

        compteur = 0
        compteur_val = 0
        while epoch < self.config.n_epochs:
            print("Epoch:", epoch)
            epoch += 1
            for num_1, batch in enumerate(generator_training):
                all_histoire_debut_embedding = Variable(torch.FloatTensor(batch[0])).cuda()
                all_histoire_fin_embedding = Variable(torch.FloatTensor(batch[1])).cuda()
                all_histoire_debut_noise = Variable(torch.FloatTensor(batch[2])).cuda()
                all_histoire_fin_noise = Variable(torch.FloatTensor(batch[3])).cuda()
                if num_1 % 2 == 0:
                    encoder_optimizer_source.zero_grad()
                    decoder_optimizer_source.zero_grad()
                    encoder_optimizer_target.zero_grad()
                    decoder_optimizer_target.zero_grad()
                    encoder_src.train(True)
                    decoder_src.train(True)
                    encoder_tgt.train(True)
                    decoder_tgt.train(True)
                    discriminator.train(False)
                    z_src_autoencoder = encoder_src(all_histoire_debut_noise)
                    out_src_auto = decoder_src(z_src_autoencoder)
                    # Denoising reconstruction loss against the clean beginning
                    # embedding, mirroring loss2 below
                    loss1 = torch.nn.functional.cosine_embedding_loss(
                        out_src_auto.transpose(0, 2).transpose(0, 1),
                        all_histoire_debut_embedding.transpose(0, 2).transpose(0, 1),
                        target_adver_tgt)
                    z_tgt_autoencoder = encoder_tgt(all_histoire_fin_noise)
                    out_tgt_auto = decoder_tgt(z_tgt_autoencoder)
                    loss2 = torch.nn.functional.cosine_embedding_loss(
                        out_tgt_auto.transpose(0, 2).transpose(0, 1),
                        all_histoire_fin_embedding.transpose(0, 2).transpose(0, 1),
                        target_adver_tgt)
                    if epoch == 1:
                        y_src_eval = all_histoire_debut_noise
                        y_tgt_eval = all_histoire_fin_noise
                    else:
                        encoder_src.train(False)
                        decoder_src.train(False)
                        encoder_tgt.train(False)
                        decoder_tgt.train(False)
                        y_tgt_eval = decoder_tgt(encoder_src(all_histoire_debut_embedding))
                        y_src_eval = decoder_src(encoder_tgt(all_histoire_fin_embedding))
                        encoder_src.train(True)
                        decoder_src.train(True)
                        encoder_tgt.train(True)
                        decoder_tgt.train(True)

                    z_src_cross = encoder_src(y_src_eval)
                    pred_fin = decoder_tgt(z_src_cross)
                    loss3 = torch.nn.functional.cosine_embedding_loss(
                        pred_fin.transpose(0, 2).transpose(0, 1),
                        all_histoire_fin_embedding.transpose(0, 2).transpose(0, 1),
                        target_adver_tgt)
                    # evaluate2
                    z_tgt_cross = encoder_tgt(y_tgt_eval)
                    pred_debut = decoder_src(z_tgt_cross)
                    loss4 = torch.nn.functional.cosine_embedding_loss(
                        pred_debut.transpose(0, 2).transpose(0, 1),
                        all_histoire_debut_embedding.transpose(0, 2).transpose(0, 1),
                        target_adver_tgt)
                    total_loss = loss1 + loss2 + loss3 + loss4

                    total_loss.backward()
                    encoder_optimizer_source.step()
                    decoder_optimizer_source.step()
                    encoder_optimizer_target.step()
                    decoder_optimizer_target.step()
                    accuracy_summary = tf.Summary()
                    main_loss_total = total_loss.item()
                    accuracy_summary.value.add(tag='train_loss_main',
                                               simple_value=main_loss_total)
                    writer.add_summary(accuracy_summary, num_1)
                    plot_loss_total += main_loss_total
                else:
                    # Rebuild the discriminator inputs from the batch: mix story
                    # beginnings and endings with the matching labels
                    new_X = []
                    new_Y_src = []
                    new_Y_tgt = []
                    for num_3, batch_3 in enumerate(all_histoire_debut_embedding):
                        if random.random() > 0.5:
                            new_X.append(batch_3.cpu().numpy())
                            new_Y_src.append(1)
                            new_Y_tgt.append(0)
                        else:
                            new_X.append(all_histoire_fin_embedding[num_3].cpu().numpy())
                            new_Y_src.append(0)
                            new_Y_tgt.append(1)
                    all_histoire_debut_noise = Variable(torch.FloatTensor(np.array(new_X))).cuda()
                    target_adver_src = Variable(torch.FloatTensor(np.array(new_Y_src))).cuda()
                    target_adver_tgt = Variable(torch.FloatTensor(np.array(new_Y_tgt))).cuda()
                    discriminator_optimizer.zero_grad()
                    discriminator.train(True)
                    encoder_src.train(False)
                    decoder_src.train(False)
                    encoder_tgt.train(False)
                    decoder_tgt.train(False)
                    z_src_autoencoder = encoder_src(all_histoire_debut_noise)
                    pred_discriminator_src = discriminator(z_src_autoencoder)
                    pred_discriminator_src = pred_discriminator_src.view(-1)
                    adv_loss1 = criterion_adver(pred_discriminator_src, target_adver_src)
                    z_tgt_autoencoder = encoder_tgt(all_histoire_debut_noise)
                    pred_discriminator_tgt = discriminator(z_tgt_autoencoder)
                    pred_discriminator_tgt = pred_discriminator_tgt.view(-1)
                    adv_loss2 = criterion_adver(pred_discriminator_tgt, target_adver_tgt)
                    if epoch == 1:
                        y_src_eval = all_histoire_debut_noise
                        y_tgt_eval = all_histoire_fin_noise
                    else:
                        encoder_src.train(False)
                        decoder_src.train(False)
                        encoder_tgt.train(False)
                        decoder_tgt.train(False)
                        y_tgt_eval = decoder_tgt(encoder_src(all_histoire_debut_embedding))
                        # the ending embedding is fed here, as in the first branch
                        y_src_eval = decoder_src(encoder_tgt(all_histoire_fin_embedding))
                        encoder_src.train(True)
                        decoder_src.train(True)
                        encoder_tgt.train(True)
                        decoder_tgt.train(True)
                    # evaluate1
                    z_src_cross = encoder_src(y_src_eval)
                    pred_discriminator_src = discriminator(z_src_cross)
                    pred_discriminator_src = pred_discriminator_src.view(-1)
                    adv_loss3 = criterion_adver(pred_discriminator_src, target_adver_src)
                    # evaluate2
                    z_tgt_cross = encoder_tgt(y_tgt_eval)
                    pred_discriminator_tgt = discriminator(z_tgt_cross)
                    pred_discriminator_tgt = pred_discriminator_tgt.view(-1)
                    adv_loss4 = criterion_adver(pred_discriminator_tgt, target_adver_tgt)
                    total_loss_adv = adv_loss1 + adv_loss2 + adv_loss3 + adv_loss4
                    total_loss_adv.backward()
                    discriminator_optimizer.step()
                    accuracy_summary = tf.Summary()
                    main_loss_total_adv = total_loss_adv.item()
                    accuracy_summary.value.add(tag='train_loss_main_adv',
                                               simple_value=main_loss_total_adv)
                    writer.add_summary(accuracy_summary, num_1)
                    plot_loss_total_adv += main_loss_total_adv
                if num_1 % self.config.plot_every == self.config.plot_every - 1:
                    plot_loss_avg = plot_loss_total / self.config.plot_every
                    plot_loss_avg_adv = plot_loss_total_adv / self.config.plot_every
                    print_summary = '%s (%d %d%%) %.4f %.4f' % (
                        self.time_since(start, (num_1 + 1) / (90000 / 32)),
                        (num_1 + 1), (num_1 + 1) / (90000 / 32) * 100,
                        plot_loss_avg, plot_loss_avg_adv)
                    print(print_summary)
                    plot_loss_total = 0
                    plot_loss_total_adv = 0  # reset both running totals
                    compteur_val += 1
                    if compteur_val == 3:
                        compteur_val = 0
                        correct = 0
                        correctfin = 0
                        correctdebut = 0
                        total = 0
                        for num, batch in enumerate(generator_dev):
                            compteur += 1
                            encoder_src.train(False)
                            decoder_src.train(False)
                            encoder_tgt.train(False)
                            decoder_tgt.train(False)
                            discriminator.train(False)
                            if num < 11:
                                all_histoire_debut_embedding = Variable(torch.FloatTensor(batch[0]))
                                all_histoire_fin_embedding1 = Variable(torch.FloatTensor(batch[1]))
                                all_histoire_fin_embedding2 = Variable(torch.FloatTensor(batch[2]))
                                if USE_CUDA:
                                    all_histoire_debut_embedding = all_histoire_debut_embedding.cuda()
                                    all_histoire_fin_embedding1 = all_histoire_fin_embedding1.cuda()
                                    all_histoire_fin_embedding2 = all_histoire_fin_embedding2.cuda()
                                labels = Variable(torch.LongTensor(batch[3]))
                                end = decoder_tgt(encoder_src(all_histoire_debut_embedding))
                                z_end1 = encoder_src(all_histoire_fin_embedding1)
                                z_end2 = encoder_src(all_histoire_fin_embedding2)

                                pred1 = discriminator(z_end1)
                                pred1 = pred1.view(-1)
                                pred2 = discriminator(z_end2)
                                pred2 = pred2.view(-1)

                                sim1 = torch.nn.functional.cosine_embedding_loss(
                                    end.transpose(0, 2).transpose(0, 1),
                                    all_histoire_fin_embedding1.transpose(0, 2).transpose(0, 1),
                                    target_adver_tgt, reduce=False)
                                # sim2 scores the second candidate ending; the
                                # original reused ending 1, making sim1 == sim2
                                sim2 = torch.nn.functional.cosine_embedding_loss(
                                    end.transpose(0, 2).transpose(0, 1),
                                    all_histoire_fin_embedding2.transpose(0, 2).transpose(0, 1),
                                    target_adver_tgt, reduce=False)
                                preds = (pred1 < pred2).cpu().long()
                                preds_sim = (sim1 > sim2).cpu().long()

                                correct += (preds == labels).sum().item()
                                correctdebut += (preds_sim == labels).sum().item()
                                total += self.config.batch_size
                                print("Accuracy ")
                                print(correct/ total,correctdebut/ total)
                                accuracy_summary = tf.Summary()
                                accuracy_summary.value.add(tag='val_accuracy',
                                                           simple_value=(correct / total))
                                accuracy_summary.value.add(tag='val_accuracy_similitude',
                                                           simple_value=(correctfin / total))
                                writer.add_summary(accuracy_summary,compteur)
                                if num % self.config.plot_every_test == self.config.plot_every_test - 1:
                                    plot_acc_avg = correct / total
                                    if plot_acc_avg > max_acc:
                                        torch.save(encoder_src.state_dict(),
                                                   './builds/encoder_source_best.pth')
                                        torch.save(encoder_tgt.state_dict(),
                                                   './builds/encoder_target_best.pth')
                                        torch.save(decoder_src.state_dict(),
                                                   './builds/decoder_source_best.pth')
                                        torch.save( decoder_tgt.state_dict(),
                                                   './builds/decoder_target_best.pth')
                                        max_acc = plot_acc_avg
                                        print('SAVE MODEL FOR ACCURACY : ' + str(plot_acc_avg))
                                    correct = 0
                                    correctfin = 0
                                    correctdebut = 0
                                    total = 0
                            else:
                                print('done validation')
                                encoder_src.train(True)
                                decoder_src.train(True)
                                encoder_tgt.train(True)
                                decoder_tgt.train(True)
                                break

            print('SAVE MODEL END EPOCH')
            torch.save(encoder_src.state_dict(), './builds/encoder_source_epoch' + str(epoch) + '.pth')
            torch.save(encoder_tgt.state_dict(), './builds/encoder_target_epoch' + str(epoch) + '.pth')
            torch.save(decoder_src.state_dict(), './builds/decoder_source_epoch' + str(epoch) + '.pth')
            torch.save(decoder_tgt.state_dict(), './builds/decoder_target_epoch' + str(epoch) + '.pth')
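
The progress line in this example (and in example #23) calls time_since(start, fraction), a helper in the style of the classic PyTorch seq2seq tutorial that reports elapsed time and a linear estimate of time remaining. It is not shown in these snippets; a sketch:

import math
import time

def as_minutes(s):
    m = math.floor(s / 60)
    return '%dm %ds' % (m, s - m * 60)

def time_since(since, percent):
    # Elapsed time plus a linear estimate of the time remaining.
    s = time.time() - since
    return '%s (- %s)' % (as_minutes(s), as_minutes(s / percent - s))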
Code example #20
def train(epochs=100, batch_size=1):
    # Generators
    #    img_shape = (256, 256, 3)
    netG = CycleGAN()
    netG_XY, real_X, fake_Y = netG.generator()
    netG_YX, real_Y, fake_X = netG.generator()

    reconstruct_X = netG_YX(fake_Y)
    reconstruct_Y = netG_XY(fake_X)
    # Discriminators
    netD = CycleGAN()
    netD_X = netD.discriminator()
    netD_Y = netD.discriminator()

    netD_X_predict_fake = netD_X(fake_X)
    netD_Y_predict_fake = netD_Y(fake_Y)
    netD_X_predict_real = netD_X(real_X)
    netD_Y_predict_real = netD_Y(real_Y)
    #    netD_X.summary()
    # Optimizer
    optimizer = Adam(lr=0.001,
                     beta_1=0.5,
                     beta_2=0.999,
                     epsilon=None,
                     decay=0.01)
    #    netG_XY.summary()
    #    plot_model(netG_XY, to_file='./netG_XY_model_graph.png')
    #GAN
    netD_X.trainable = False  # freeze the discriminators while training the generators
    netD_Y.trainable = False
    netG_loss_inputs = [
        netD_X_predict_fake, reconstruct_X, real_X, netD_Y_predict_fake,
        reconstruct_Y, real_Y
    ]
    netG_train = Model([real_X, real_Y], Lambda(netG_loss)(netG_loss_inputs))
    netG_train.compile(loss='mae', optimizer=optimizer, metrics=['accuracy'])

    _fake_X_inputs = Input(shape=(256, 256, 3))
    _fake_Y_inputs = Input(shape=(256, 256, 3))
    _netD_X_predict_fake = netD_X(_fake_X_inputs)
    _netD_Y_predict_fake = netD_Y(_fake_Y_inputs)
    netD_X.trainable = True
    netD_X_train = Model(
        [real_X, _fake_X_inputs],
        Lambda(netD_loss)([netD_X_predict_real, _netD_X_predict_fake]))
    netD_X_train.compile(loss='mae', optimizer=optimizer,
                         metrics=['accuracy'])  # mean absolute error

    netD_X.trainable = False
    netD_Y.trainable = True
    netD_Y_train = Model(
        [real_Y, _fake_Y_inputs],
        Lambda(netD_loss)([netD_Y_predict_real, _netD_Y_predict_fake]))
    netD_Y_train.compile(loss='mae', optimizer=optimizer, metrics=['accuracy'])

    dataloader = Dataloader()
    fake_X_pool = ImagePool()
    fake_Y_pool = ImagePool()

    netG_X_function = get_G_function(netG_XY)
    netG_Y_function = get_G_function(netG_YX)
    if len(os.listdir('./weights')):
        netG_train.load_weights('./weights/netG.h5')
        netD_X_train.load_weights('./weights/netD_X.h5')
        netD_Y_train.load_weights('./weights/netD_Y.h5')

    print('Info: Start Training\n')
    for epoch in range(epochs):

        target_label = np.zeros((batch_size, 1))

        for batch_i, (imgs_X,
                      imgs_Y) in enumerate(dataloader.load_batch(batch_size)):
            start_time = time.time()
            tmp_fake_X = netG_X_function([imgs_X])[0]
            tmp_fake_Y = netG_Y_function([imgs_Y])[0]

            # Draw previously generated images from the history buffer
            _fake_X = fake_X_pool.action(tmp_fake_X)
            _fake_Y = fake_Y_pool.action(tmp_fake_Y)
            if batch_i % 2 == 0:
                save_image('fake_X_' + str(epoch) + '_' + str(batch_i),
                           _fake_X[0])
                save_image('fake_Y_' + str(epoch) + '_' + str(batch_i),
                           _fake_Y[0])
            _netG_loss = netG_train.train_on_batch([imgs_X, imgs_Y],
                                                   target_label)
            netD_X_loss = netD_X_train.train_on_batch([imgs_X, _fake_X],
                                                      target_label)
            netD_Y_loss = netD_Y_train.train_on_batch([imgs_Y, _fake_Y],
                                                      target_label)
            diff = time.time() - start_time
            print('Epoch:{}/{}, netG_loss:{}, netD_X_loss:{}, netD_Y_loss:{}, {:.2f}s/batch'
                  .format(epoch + 1, epochs, _netG_loss, netD_X_loss, netD_Y_loss, diff))

        netG_train.save_weights('./weights/netG.h5')
        netD_X_train.save_weights('./weights/netD_X.h5')
        netD_Y_train.save_weights('./weights/netD_Y.h5')
        print('Model saved!\n')
Code example #21
File: main.py Project: endiqq/Fus-CNNs_COVID-19
                print('-' * 30)
                for k in range(K_fold):
                    print('K_fold = %d' % k)

                    if args.action == 'train':
                        # tensorboard
                        writer = SummaryWriter('runs/' + args.network + '_' +
                                               Dataset + '_' + str(k))
                        model, size, pretrained, num_ftrs = tf_learning(
                        ).def_model(args.network,
                                    num_classes,
                                    ff,
                                    use_pretrained=True)
                        model_conv = model.to(device)

                        dataloaders = Dataloader(Dataset).data_loader(
                            size, k, batch_size)

                        params_to_update = []
                        print("Params to learn:")
                        for name, param in model_conv.named_parameters():
                            if param.requires_grad:
                                params_to_update.append(param)
                                print("\t", name)

                        _, hist, best_acc, Best_model_wts, Last_model_wts = Trainer_models.train_model(
                            model_conv, dataloaders, nepoch, params_to_update,
                            writer)
                        # save the last validation model weights and the best validation model weights
                        torch.save(Best_model_wts,
                                   'Aug_Best_' + Dataset + '_' + ff + '_' +
                                   args.network + '_' + Dataset + '_k_' +
Code example #22
    def train(self):
        training_set = Dataloader(self.config)
        training_set.load_dataset('./data/train.bin')
        training_set.load_vocab('./data/default.voc', self.config.vocab_size)

        testing_set = Dataloader(self.config, testing_data=True)
        testing_set.load_dataset('data/test.bin')
        testing_set.load_vocab('./data/default.voc', self.config.vocab_size)

        main(self.config, training_set, testing_set)
Code example #23
    def train(self):
        output_fn = OutputFN(self.config.GLOVE_PATH, self.config.model_path)
        train_set = Dataloader(self.config, 'data/train_stories.csv')
        test_set = Dataloader(self.config,
                              'data/test_stories.csv',
                              testing_data=True)
        train_set.set_special_tokens(["<unk>"])
        test_set.set_special_tokens(["<unk>"])
        train_set.load_dataset('data/train.bin')
        train_set.load_vocab('./data/default.voc', self.config.vocab_size)
        test_set.load_dataset('data/test.bin')
        test_set.load_vocab('./data/default.voc', self.config.vocab_size)
        test_set.set_output_fn(output_fn.output_fn_test)
        train_set.set_output_fn(output_fn)
        generator_training = train_set.get_batch(self.config.batch_size, 1)
        generator_dev = test_set.get_batch(self.config.batch_size, 1)
        epoch = 0
        max_acc = 0
        plot_losses_train = []
        # plot_losses_train_adv=[]
        plot_losses_train_cross = []
        plot_losses_train_auto = []
        plot_accurracies_avg = []
        plot_accurracies_avg_val = []
        start = time.time()
        Seq2SEq_main_model = Seq2SeqTrainer(self.config.hidden_size,
                                            self.config.embedding_size,
                                            self.config.n_layers,
                                            self.config.batch_size,
                                            self.config.attention_bolean,
                                            dropout=0.5,
                                            learning_rate=0.0003,
                                            plot_every=20,
                                            print_every=100,
                                            evaluate_every=1000)
        timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        writer = tf.summary.FileWriter('./logs/' + timestamp + '-concept-fb/')
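        # The TF 1.x FileWriter is used here purely to push PyTorch scalars to
        # TensorBoard. The repeated tf.Summary() blocks below could be factored
        # into a helper (a sketch using the same TF 1.x API):
        #   def log_scalar(writer, tag, value, step):
        #       summary = tf.Summary()
        #       summary.value.add(tag=tag, simple_value=value)
        #       writer.add_summary(summary, step)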
        plot_loss_total = 0
        plot_loss_total_auto = 0
        plot_loss_total_cross = 0
        compteur_val = 0
        while epoch < self.config.n_epochs:
            print("Epoch:", epoch)
            epoch += 1
            for phase in ['train', 'test']:
                print(phase)
                if phase == 'train':
                    for num_1, batch in enumerate(generator_training):
                        main_loss_total, loss_auto_debut, loss_auto_fin, loss_cross_debut, loss_cross_fin = Seq2SEq_main_model.train_all(
                            batch)
                        accuracy_summary = tf.Summary()
                        accuracy_summary.value.add(tag='train_loss_main',
                                                   simple_value=main_loss_total)
                        accuracy_summary.value.add(tag='train_loss_auto',
                                                   simple_value=loss_auto_debut + loss_auto_fin)
                        accuracy_summary.value.add(tag='train_loss_cross',
                                                   simple_value=loss_cross_debut + loss_cross_fin)
                        writer.add_summary(accuracy_summary, num_1)
                        plot_loss_total += main_loss_total
                        plot_loss_total_auto += loss_auto_debut + loss_auto_fin
                        plot_loss_total_cross += loss_cross_debut + loss_cross_fin
                        if num_1 % self.config.plot_every == self.config.plot_every - 1:
                            plot_loss_avg = plot_loss_total / self.config.plot_every
                            plot_loss_auto_avg = plot_loss_total_auto / self.config.plot_every
                            plot_loss_cross_avg = plot_loss_total_cross / self.config.plot_every
                            plot_losses_train.append(plot_loss_avg)
                            # plot_losses_train_adv.append(plot_loss_adv_avg)
                            plot_losses_train_auto.append(plot_loss_auto_avg)
                            plot_losses_train_cross.append(plot_loss_cross_avg)
                            #np.save('./builds/main_loss', np.array(plot_losses_train))
                            #np.save('./builds/adv_loss', np.array(plot_losses_train_adv))
                            #np.save('./builds/auto_loss', np.array(plot_losses_train_auto))
                            #np.save('./builds/cross_loss', np.array(plot_losses_train_cross))
                            # 90000 / 32 is presumably the number of training batches
                            # per epoch (training-set size / batch size)
                            print_summary = '%s (%d %d%%) %.4f %.4f %.4f' % (
                                Seq2SEq_main_model.time_since(start,
                                                              (num_1 + 1) / (90000 / 32)),
                                num_1 + 1, (num_1 + 1) / (90000 / 32) * 100,
                                plot_loss_avg, plot_loss_auto_avg, plot_loss_cross_avg)
                            print(print_summary)
                            plot_loss_total = 0
                            plot_loss_total_auto = 0
                            plot_loss_total_cross = 0
                            compteur_val += 1
                            if compteur_val == 3:
                                compteur_val = 0
                                correct = 0
                                correctfin = 0
                                correctdebut = 0
                                dcorrect = 0
                                dcorrectfin = 0
                                dcorrectdebut = 0
                                total = 0
                                for num, batch in enumerate(generator_dev):
                                    if num < 21:
                                        all_histoire_debut_embedding = Variable(
                                            torch.FloatTensor(batch[0])).transpose(0, 1)
                                        all_histoire_fin_embedding1 = Variable(
                                            torch.FloatTensor(batch[1])).transpose(0, 1)
                                        all_histoire_fin_embedding2 = Variable(
                                            torch.FloatTensor(batch[2])).transpose(0, 1)
                                        if USE_CUDA:
                                            all_histoire_debut_embedding = all_histoire_debut_embedding.cuda()
                                            all_histoire_fin_embedding1 = all_histoire_fin_embedding1.cuda()
                                            all_histoire_fin_embedding2 = all_histoire_fin_embedding2.cuda()
                                        labels = Variable(torch.LongTensor(batch[3]))
                                        end = Seq2SEq_main_model.evaluate(
                                            Seq2SEq_main_model.encoder_source,
                                            Seq2SEq_main_model.decoder_target,
                                            all_histoire_debut_embedding,
                                            Seq2SEq_main_model.input_length_debut)
                                        debut1 = Seq2SEq_main_model.evaluate(
                                            Seq2SEq_main_model.encoder_source,
                                            Seq2SEq_main_model.decoder_target,
                                            all_histoire_fin_embedding1,
                                            Seq2SEq_main_model.input_length_fin)
                                        debut2 = Seq2SEq_main_model.evaluate(
                                            Seq2SEq_main_model.encoder_source,
                                            Seq2SEq_main_model.decoder_target,
                                            all_histoire_fin_embedding2,
                                            Seq2SEq_main_model.input_length_fin)
                                        preds, predfin, preddebut, preds_dist, predfin_dis, preddebut_dis = self.get_predict(
                                            end, debut1, debut2,
                                            all_histoire_debut_embedding.transpose(0, 1),
                                            all_histoire_fin_embedding1.transpose(0, 1),
                                            all_histoire_fin_embedding2.transpose(0, 1))
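                                        # get_predict presumably scores each candidate ending
                                        # under two criteria: a collinearity/cosine score
                                        # (preds, predfin, preddebut) and a distance score
                                        # (the *_dist/*_dis variants), both tallied below.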

                                        preds = preds.cpu().long()
                                        predfin = predfin.cpu().long()
                                        preddebut = preddebut.cpu().long()

                                        correct += (preds == labels).sum().item()
                                        correctfin += (predfin == labels).sum().item()
                                        correctdebut += (preddebut == labels).sum().item()

                                        preds_dist = preds_dist.cpu().long()
                                        predfin_dis = predfin_dis.cpu().long()
                                        preddebut_dis = preddebut_dis.cpu().long()

                                        dcorrect += (preds_dist == labels).sum().item()
                                        dcorrectfin += (predfin_dis == labels).sum().item()
                                        dcorrectdebut += (preddebut_dis == labels).sum().item()

                                        total += self.config.batch_size

                                        print("Collinearity accuracy: sum, ending, beginning")
                                        print(correct / total, correctfin / total,
                                              correctdebut / total)
                                        print("Distance accuracy: sum, ending, beginning")
                                        print(dcorrect / total, dcorrectfin / total,
                                              dcorrectdebut / total)

                                        accuracy_summary = tf.Summary()
                                        accuracy_summary.value.add(
                                            tag='val_accuracy',
                                            simple_value=correct / total)
                                        accuracy_summary.value.add(
                                            tag='val_accuracy_fin',
                                            simple_value=correctfin / total)
                                        accuracy_summary.value.add(
                                            tag='val_accuracy_debut',
                                            simple_value=correctdebut / total)
                                        accuracy_summary.value.add(
                                            tag='val_accuracy_dist',
                                            simple_value=dcorrect / total)
                                        accuracy_summary.value.add(
                                            tag='val_accuracy_fin_dist',
                                            simple_value=dcorrectfin / total)
                                        accuracy_summary.value.add(
                                            tag='val_accuracy_debut_dist',
                                            simple_value=dcorrectdebut / total)
                                        writer.add_summary(accuracy_summary,
                                                           num + num_1 - 1)
                                        if num % self.config.plot_every_test == self.config.plot_every_test - 1:
                                            plot_acc_avg = correct / total
                                            plot_accuracies_avg_val.append(plot_acc_avg)
                                            if plot_acc_avg > max_acc:
                                                # keep the best-validation weights of each sub-module
                                                torch.save(
                                                    Seq2SEq_main_model.encoder_source.state_dict(),
                                                    './builds/encoder_source_best.pth')
                                                torch.save(
                                                    Seq2SEq_main_model.encoder_target.state_dict(),
                                                    './builds/encoder_target_best.pth')
                                                torch.save(
                                                    Seq2SEq_main_model.decoder_source.state_dict(),
                                                    './builds/decoder_source_best.pth')
                                                torch.save(
                                                    Seq2SEq_main_model.decoder_target.state_dict(),
                                                    './builds/decoder_target_best.pth')
                                                max_acc = plot_acc_avg
                                                print('SAVE MODEL FOR ACCURACY : ' + str(plot_acc_avg))
                                            correct = 0
                                            correctfin = 0
                                            correctdebut = 0
                                            dcorrect = 0
                                            dcorrectfin = 0
                                            dcorrectdebut = 0
                                            total = 0
                                    else:
                                        print('done validation')
                                        break
                else:
                    print(phase)
                    correct = 0
                    correctfin = 0
                    correctdebut = 0
                    dcorrect = 0
                    dcorrectfin = 0
                    dcorrectdebut = 0
                    total = 0
                    for num, batch in enumerate(generator_dev):
                        all_histoire_debut_embedding = Variable(
                            torch.FloatTensor(batch[0])).transpose(0, 1)
                        all_histoire_fin_embedding1 = Variable(
                            torch.FloatTensor(batch[1])).transpose(0, 1)
                        all_histoire_fin_embedding2 = Variable(
                            torch.FloatTensor(batch[2])).transpose(0, 1)
                        if USE_CUDA:
                            all_histoire_debut_embedding = all_histoire_debut_embedding.cuda()
                            all_histoire_fin_embedding1 = all_histoire_fin_embedding1.cuda()
                            all_histoire_fin_embedding2 = all_histoire_fin_embedding2.cuda()
                        labels = Variable(torch.LongTensor(batch[3]))
                        end = Seq2SEq_main_model.evaluate(
                            Seq2SEq_main_model.encoder_source,
                            Seq2SEq_main_model.decoder_target,
                            all_histoire_debut_embedding,
                            Seq2SEq_main_model.input_length_debut)
                        debut1 = Seq2SEq_main_model.evaluate(
                            Seq2SEq_main_model.encoder_source,
                            Seq2SEq_main_model.decoder_target,
                            all_histoire_fin_embedding1,
                            Seq2SEq_main_model.input_length_fin)
                        debut2 = Seq2SEq_main_model.evaluate(
                            Seq2SEq_main_model.encoder_source,
                            Seq2SEq_main_model.decoder_target,
                            all_histoire_fin_embedding2,
                            Seq2SEq_main_model.input_length_fin)
                        preds, predfin, preddebut, preds_dist, predfin_dis, preddebut_dis = self.get_predict(
                            end, debut1, debut2,
                            all_histoire_debut_embedding.transpose(0, 1),
                            all_histoire_fin_embedding1.transpose(0, 1),
                            all_histoire_fin_embedding2.transpose(0, 1))

                        preds = preds.cpu().long()
                        predfin = predfin.cpu().long()
                        preddebut = preddebut.cpu().long()

                        correct += (preds == labels).sum().item()
                        correctfin += (predfin == labels).sum().item()
                        correctdebut += (preddebut == labels).sum().item()

                        preds_dist = preds_dist.cpu().long()
                        predfin_dis = predfin_dis.cpu().long()
                        preddebut_dis = preddebut_dis.cpu().long()

                        dcorrect += (preds_dist == labels).sum().item()
                        dcorrectfin += (predfin_dis == labels).sum().item()
                        dcorrectdebut += (preddebut_dis == labels).sum().item()

                        total += self.config.batch_size
                        accuracy_summary = tf.Summary()
                        accuracy_summary.value.add(tag='test_accuracy',
                                                   simple_value=correct / total)
                        accuracy_summary.value.add(tag='test_accuracy_fin',
                                                   simple_value=correctfin / total)
                        accuracy_summary.value.add(tag='test_accuracy_debut',
                                                   simple_value=correctdebut / total)
                        accuracy_summary.value.add(tag='test_accuracy_dist',
                                                   simple_value=dcorrect / total)
                        accuracy_summary.value.add(tag='test_accuracy_fin_dist',
                                                   simple_value=dcorrectfin / total)
                        accuracy_summary.value.add(tag='test_accuracy_debut_dist',
                                                   simple_value=dcorrectdebut / total)
                        writer.add_summary(accuracy_summary, num - 1)
                        if num % self.config.plot_every_test == self.config.plot_every_test - 1:
                            plot_acc_avg = correct / total
                            plot_accuracies_avg.append(plot_acc_avg)
                            #np.save('./builds/accuracy_test', np.array(plot_accuracies_avg))
                            correct = 0
                            correctfin = 0
                            correctdebut = 0
                            dcorrect = 0
                            dcorrectfin = 0
                            dcorrectdebut = 0
                            total = 0

                print('SAVE MODEL END EPOCH')
                torch.save(
                    Seq2SEq_main_model.encoder_source.state_dict(),
                    './builds/encoder_source_epoch' + str(epoch) + '.pth')
                torch.save(
                    Seq2SEq_main_model.encoder_target.state_dict(),
                    './builds/encoder_target_epoch' + str(epoch) + '.pth')
                torch.save(
                    Seq2SEq_main_model.decoder_source.state_dict(),
                    './builds/decoder_source_epoch' + str(epoch) + '.pth')
                torch.save(
                    Seq2SEq_main_model.decoder_target.state_dict(),
                    './builds/decoder_target_epoch' + str(epoch) + '.pth')
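                # A checkpoint saved this way can later be restored with
                # load_state_dict (a sketch; it assumes the modules are rebuilt
                # with identical constructor arguments):
                #   Seq2SEq_main_model.encoder_source.load_state_dict(
                #       torch.load('./builds/encoder_source_best.pth'))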