import numpy as np
from tensorflow import keras
from tensorflow.keras import Sequential
from tensorflow.keras import layers
from data_generator import training_generator, test_generator
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
from encoder_decoder import EncoderDecoder
from model_wrapper import ModelWrapper, build_config

CHARS = '0123456789+ '
enc_dec = EncoderDecoder(CHARS)


def encode_generator(generator, batch_size):
    return map(enc_dec.encode, generator(batch_size=batch_size))


MODEL_NAME = 'number_addition'
LEARNING_RATE = 'default'
BATCH_SIZE = 512
STEPS_PER_EPOCH = 500
EPOCHS = 10
HIDDEN_SIZE = 256
RNN = layers.LSTM

callbacks = [
    ModelCheckpoint(filepath='checkpoint.h5',
                    verbose=1, save_best_only=True),
    EarlyStopping(monitor='val_loss', min_delta=0,
                  patience=10, verbose=2, mode='auto'),
    TensorBoard(log_dir='logs', histogram_freq=0),  # remaining arguments were truncated in the original listing
]
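A minimal sketch of how these constants, generators, and callbacks can be wired into a seq2seq addition model (not part of the original example; MAX_QUESTION_LEN, MAX_ANSWER_LEN, and validation_steps are assumed placeholders, and it assumes encode_generator yields (x, y) batches of one-hot encoded sequences):

# --- Sketch only: names below marked "assumed" are not from the original snippet ---
MAX_QUESTION_LEN = 7   # assumed, e.g. '123+456'
MAX_ANSWER_LEN = 4     # assumed, e.g. ' 579'

model = Sequential([
    RNN(HIDDEN_SIZE, input_shape=(MAX_QUESTION_LEN, len(CHARS))),            # encoder
    layers.RepeatVector(MAX_ANSWER_LEN),                                     # repeat encoder state to decoder length
    RNN(HIDDEN_SIZE, return_sequences=True),                                 # decoder
    layers.TimeDistributed(layers.Dense(len(CHARS), activation='softmax')),  # per-step character distribution
])
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.fit(encode_generator(training_generator, BATCH_SIZE),
          steps_per_epoch=STEPS_PER_EPOCH,
          epochs=EPOCHS,
          validation_data=encode_generator(test_generator, BATCH_SIZE),
          validation_steps=50,   # assumed; the generators are endless, so a step count is required
          callbacks=callbacks)

The RepeatVector layer bridges the single encoder state to the decoder's output length, which is the standard Keras pattern for fixed-length sequence-to-sequence tasks.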
Example #2
def train_dcc(s_dataloaders, t_dataloaders, val_dataloader, test_dataloader,
              metric_name, seed, **kwargs):
    """
    :param s_dataloaders:
    :param t_dataloaders:
    :param kwargs:
    :return:
    """
    s_train_dataloader = s_dataloaders
    t_train_dataloader = t_dataloaders

    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])
    encoder = autoencoder.encoder

    target_decoder = MoMLP(input_dim=kwargs['latent_dim'],
                           output_dim=kwargs['output_dim'],
                           hidden_dims=kwargs['regressor_hidden_dims'],
                           out_fn=torch.nn.Sigmoid).to(kwargs['device'])

    target_regressor = EncoderDecoder(
        encoder=encoder, decoder=target_decoder).to(kwargs['device'])

    train_history = defaultdict(list)
    # ae_eval_train_history = defaultdict(list)
    val_history = defaultdict(list)
    s_target_regression_eval_train_history = defaultdict(list)
    t_target_regression_eval_train_history = defaultdict(list)
    target_regression_eval_val_history = defaultdict(list)
    target_regression_eval_test_history = defaultdict(list)

    model_optimizer = torch.optim.AdamW(target_regressor.parameters(),
                                        lr=kwargs['lr'])
    for epoch in range(int(kwargs['train_num_epochs'])):
        if epoch % 50 == 0:
            print(f'Coral training epoch {epoch}')
        for step, s_batch in enumerate(s_train_dataloader):
            t_batch = next(iter(t_train_dataloader))
            train_history = dcc_train_step(model=target_regressor,
                                           s_batch=s_batch,
                                           t_batch=t_batch,
                                           device=kwargs['device'],
                                           optimizer=model_optimizer,
                                           alpha=kwargs['alpha'],
                                           history=train_history)
        s_target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=s_train_dataloader,
            device=kwargs['device'],
            history=s_target_regression_eval_train_history)

        t_target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=t_train_dataloader,
            device=kwargs['device'],
            history=t_target_regression_eval_train_history)
        target_regression_eval_val_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=val_dataloader,
            device=kwargs['device'],
            history=target_regression_eval_val_history)
        target_regression_eval_test_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=test_dataloader,
            device=kwargs['device'],
            history=target_regression_eval_test_history)

        save_flag, stop_flag = model_save_check(
            history=target_regression_eval_val_history,
            metric_name=metric_name,
            tolerance_count=50)
        if save_flag:
            torch.save(
                target_regressor.state_dict(),
                os.path.join(kwargs['model_save_folder'],
                             f'dcc_regressor_{seed}.pt'))
        if stop_flag:
            break
    target_regressor.load_state_dict(
        torch.load(
            os.path.join(kwargs['model_save_folder'],
                         f'dcc_regressor_{seed}.pt')))

    # evaluate_target_regression_epoch(regressor=target_regressor,
    #                                  dataloader=val_dataloader,
    #                                  device=kwargs['device'],
    #                                  history=None,
    #                                  seed=seed,
    #                                  output_folder=kwargs['model_save_folder'])
    evaluate_target_regression_epoch(regressor=target_regressor,
                                     dataloader=test_dataloader,
                                     device=kwargs['device'],
                                     history=None,
                                     seed=seed,
                                     output_folder=kwargs['model_save_folder'])
    return target_regressor, (train_history,
                              s_target_regression_eval_train_history,
                              t_target_regression_eval_train_history,
                              target_regression_eval_val_history,
                              target_regression_eval_test_history)
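A hypothetical invocation of train_dcc follows; only the keyword names mirror what the function reads out of **kwargs, while every concrete value and the dataloader variables are placeholders, not values specified by the original code.

# Hypothetical call; values are placeholders, only the kwarg keys are taken from train_dcc's body.
dcc_regressor, dcc_histories = train_dcc(
    s_dataloaders=s_train_loader,          # source-domain DataLoader (placeholder)
    t_dataloaders=t_train_loader,          # target-domain DataLoader (placeholder)
    val_dataloader=val_loader,
    test_dataloader=test_loader,
    metric_name='cpearsonr',
    seed=2020,
    input_dim=1000,
    latent_dim=128,
    encoder_hidden_dims=[512, 256],
    regressor_hidden_dims=[64, 32],
    output_dim=1,
    dop=0.1,
    lr=1e-4,
    alpha=1.0,
    train_num_epochs=500,
    device='cuda',
    model_save_folder='./model_save')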
Example #3
def train_cleita(dataloader, seed, **kwargs):
    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])

    # get reference encoder
    aux_ae = deepcopy(autoencoder)

    aux_ae.encoder.load_state_dict(
        torch.load(os.path.join('./model_save', f'ft_encoder_{seed}.pt')))
    print('reference encoder loaded')
    reference_encoder = aux_ae.encoder

    # construct transmitter
    transmitter = MLP(input_dim=kwargs['latent_dim'],
                      output_dim=kwargs['latent_dim'],
                      hidden_dims=[kwargs['latent_dim']]).to(kwargs['device'])

    confounding_classifier = MLP(input_dim=kwargs['latent_dim'],
                                 output_dim=1,
                                 hidden_dims=kwargs['classifier_hidden_dims'],
                                 dop=kwargs['dop']).to(kwargs['device'])

    ae_train_history = defaultdict(list)
    ae_val_history = defaultdict(list)
    critic_train_history = defaultdict(list)
    gen_train_history = defaultdict(list)

    if kwargs['retrain_flag']:
        cleit_params = [autoencoder.parameters(), transmitter.parameters()]
        cleit_optimizer = torch.optim.AdamW(chain(*cleit_params),
                                            lr=kwargs['lr'])
        classifier_optimizer = torch.optim.RMSprop(
            confounding_classifier.parameters(), lr=kwargs['lr'])
        for epoch in range(int(kwargs['train_num_epochs'])):
            if epoch % 50 == 0:
                print(f'confounder wgan training epoch {epoch}')
            for step, batch in enumerate(dataloader):
                critic_train_history = critic_train_step(
                    critic=confounding_classifier,
                    ae=autoencoder,
                    reference_encoder=reference_encoder,
                    transmitter=transmitter,
                    batch=batch,
                    device=kwargs['device'],
                    optimizer=classifier_optimizer,
                    history=critic_train_history,
                    # clip=0.1,
                    gp=10.0)
                if (step + 1) % 5 == 0:
                    gen_train_history = gan_gen_train_step(
                        critic=confounding_classifier,
                        ae=autoencoder,
                        transmitter=transmitter,
                        batch=batch,
                        device=kwargs['device'],
                        optimizer=cleit_optimizer,
                        alpha=1.0,
                        history=gen_train_history)

        torch.save(autoencoder.state_dict(),
                   os.path.join(kwargs['model_save_folder'], 'cleit_ae.pt'))
        torch.save(transmitter.state_dict(),
                   os.path.join(kwargs['model_save_folder'], 'transmitter.pt'))
    else:
        try:
            autoencoder.load_state_dict(
                torch.load(
                    os.path.join(kwargs['model_save_folder'], 'cleit_ae.pt')))
            transmitter.load_state_dict(
                torch.load(
                    os.path.join(kwargs['model_save_folder'],
                                 'transmitter.pt')))
        except FileNotFoundError:
            raise Exception("No pre-trained encoder")

    encoder = EncoderDecoder(encoder=autoencoder.encoder,
                             decoder=transmitter).to(kwargs['device'])

    return encoder, (ae_train_history, ae_val_history, critic_train_history,
                     gen_train_history)
Example #4
File: mlp_main.py  Project: he-org/CLEIT
def fine_tune_encoder(train_dataloader, val_dataloader, seed, test_dataloader=None,
                      metric_name='cpearsonr',
                      normalize_flag=False, **kwargs):
    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])
    encoder = autoencoder.encoder

    target_decoder = MoMLP(input_dim=kwargs['latent_dim'],
                           output_dim=kwargs['output_dim'],
                           hidden_dims=kwargs['regressor_hidden_dims'],
                           out_fn=torch.nn.Sigmoid).to(kwargs['device'])

    target_regressor = EncoderDecoder(encoder=encoder,
                                      decoder=target_decoder,
                                      normalize_flag=normalize_flag).to(kwargs['device'])

    target_regression_train_history = defaultdict(list)
    target_regression_eval_train_history = defaultdict(list)
    target_regression_eval_val_history = defaultdict(list)
    target_regression_eval_test_history = defaultdict(list)

    target_regression_optimizer = torch.optim.AdamW(target_regressor.parameters(), lr=kwargs['lr'])

    for epoch in range(kwargs['train_num_epochs']):
        if epoch % 10 == 0:
            print(f'MLP fine-tuning epoch {epoch}')
        for step, batch in enumerate(train_dataloader):
            target_regression_train_history = regression_train_step(model=target_regressor,
                                                                    batch=batch,
                                                                    device=kwargs['device'],
                                                                    optimizer=target_regression_optimizer,
                                                                    history=target_regression_train_history)
        target_regression_eval_train_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                                dataloader=train_dataloader,
                                                                                device=kwargs['device'],
                                                                                history=target_regression_eval_train_history)
        target_regression_eval_val_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                              dataloader=val_dataloader,
                                                                              device=kwargs['device'],
                                                                              history=target_regression_eval_val_history)

        if test_dataloader is not None:
            target_regression_eval_test_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                                   dataloader=test_dataloader,
                                                                                   device=kwargs['device'],
                                                                                   history=target_regression_eval_test_history)
        save_flag, stop_flag = model_save_check(history=target_regression_eval_val_history,
                                                metric_name=metric_name,
                                                tolerance_count=50)
        if save_flag or epoch == 0:
            torch.save(target_regressor.state_dict(),
                       os.path.join(kwargs['model_save_folder'], f'target_regressor_{seed}.pt'))
            torch.save(target_regressor.encoder.state_dict(),
                       os.path.join(kwargs['model_save_folder'], f'ft_encoder_{seed}.pt'))
        if stop_flag:
            break

    target_regressor.load_state_dict(
        torch.load(os.path.join(kwargs['model_save_folder'], f'target_regressor_{seed}.pt')))

    evaluate_target_regression_epoch(regressor=target_regressor,
                                     dataloader=val_dataloader,
                                     device=kwargs['device'],
                                     history=None,
                                     seed=seed,
                                     cv_flag=True,
                                     output_folder=kwargs['model_save_folder'])
    if test_dataloader is not None:
        evaluate_target_regression_epoch(regressor=target_regressor,
                                         dataloader=test_dataloader,
                                         device=kwargs['device'],
                                         history=None,
                                         seed=seed,
                                         output_folder=kwargs['model_save_folder'])


    return target_regressor, (target_regression_train_history, target_regression_eval_train_history,
                              target_regression_eval_val_history, target_regression_eval_test_history)
Example #5
def train_dann(s_dataloaders, t_dataloaders, val_dataloader, test_dataloader,
               metric_name, seed, **kwargs):
    """

    :param s_dataloaders:
    :param t_dataloaders:
    :param kwargs:
    :return:
    """
    s_train_dataloader = s_dataloaders
    t_train_dataloader = t_dataloaders

    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])
    encoder = autoencoder.encoder

    target_decoder = MoMLP(input_dim=kwargs['latent_dim'],
                           output_dim=kwargs['output_dim'],
                           hidden_dims=kwargs['regressor_hidden_dims'],
                           out_fn=torch.nn.Sigmoid).to(kwargs['device'])

    target_regressor = EncoderDecoder(
        encoder=encoder, decoder=target_decoder).to(kwargs['device'])

    classifier = MLP(input_dim=kwargs['latent_dim'],
                     output_dim=1,
                     hidden_dims=kwargs['classifier_hidden_dims'],
                     dop=kwargs['dop'],
                     out_fn=torch.nn.Sigmoid).to(kwargs['device'])

    confounder_classifier = EncoderDecoder(encoder=autoencoder.encoder,
                                           decoder=classifier).to(
                                               kwargs['device'])

    train_history = defaultdict(list)
    s_target_regression_eval_train_history = defaultdict(list)
    t_target_regression_eval_train_history = defaultdict(list)
    target_regression_eval_val_history = defaultdict(list)
    target_regression_eval_test_history = defaultdict(list)

    confounded_loss = nn.BCEWithLogitsLoss()
    dann_params = [
        target_regressor.parameters(),
        confounder_classifier.decoder.parameters()
    ]
    dann_optimizer = torch.optim.AdamW(chain(*dann_params), lr=kwargs['lr'])

    # start alternative training
    for epoch in range(int(kwargs['train_num_epochs'])):
        if epoch % 50 == 0:
            print(f'DANN training epoch {epoch}')
        # start autoencoder training epoch
        for step, s_batch in enumerate(s_train_dataloader):
            t_batch = next(iter(t_train_dataloader))
            train_history = dann_train_step(classifier=confounder_classifier,
                                            model=target_regressor,
                                            s_batch=s_batch,
                                            t_batch=t_batch,
                                            loss_fn=confounded_loss,
                                            alpha=kwargs['alpha'],
                                            device=kwargs['device'],
                                            optimizer=dann_optimizer,
                                            history=train_history,
                                            scheduler=None)

        s_target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=s_train_dataloader,
            device=kwargs['device'],
            history=s_target_regression_eval_train_history)

        t_target_regression_eval_train_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=t_train_dataloader,
            device=kwargs['device'],
            history=t_target_regression_eval_train_history)
        target_regression_eval_val_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=val_dataloader,
            device=kwargs['device'],
            history=target_regression_eval_val_history)
        target_regression_eval_test_history = evaluate_target_regression_epoch(
            regressor=target_regressor,
            dataloader=test_dataloader,
            device=kwargs['device'],
            history=target_regression_eval_test_history)

        save_flag, stop_flag = model_save_check(
            history=target_regression_eval_val_history,
            metric_name=metric_name,
            tolerance_count=50)
        if save_flag:
            torch.save(
                target_regressor.state_dict(),
                os.path.join(kwargs['model_save_folder'],
                             f'dann_regressor_{seed}.pt'))
        if stop_flag:
            break
    target_regressor.load_state_dict(
        torch.load(
            os.path.join(kwargs['model_save_folder'],
                         f'dann_regressor_{seed}.pt')))

    # evaluate_target_regression_epoch(regressor=target_regressor,
    #                                  dataloader=val_dataloader,
    #                                  device=kwargs['device'],
    #                                  history=None,
    #                                  seed=seed,
    #                                  output_folder=kwargs['model_save_folder'])
    evaluate_target_regression_epoch(regressor=target_regressor,
                                     dataloader=test_dataloader,
                                     device=kwargs['device'],
                                     history=None,
                                     seed=seed,
                                     output_folder=kwargs['model_save_folder'])

    return target_regressor, (train_history,
                              s_target_regression_eval_train_history,
                              t_target_regression_eval_train_history,
                              target_regression_eval_val_history,
                              target_regression_eval_test_history)
Example #6
from encoder_decoder import EncoderDecoder
import tools

encoder_decoder = EncoderDecoder(embedding_size=300,
                                 n_hidden_RNN=1024,
                                 do_train=True)
"""
# Training
data_provider = tools.DataProvider(path_to_csv='dataset/set_for_encoding.csv',
    # path_to_w2v='embeddings/GoogleNews-vectors-negative300.bin',
    path_to_w2v='~/GoogleNews-vectors-negative300.bin',    
    test_size=0.15, path_to_vocab='dataset/vocab.pickle')

encoder_decoder.train_(data_loader=data_provider, keep_prob=1, weight_decay=0.005,
    learn_rate_start=0.005, learn_rate_end=0.0003, batch_size=64, n_iter=100000,
    save_model_every_n_iter=5000, path_to_model='models/siamese')



#Evaluating COST
encoder_decoder.eval_cost(data_loader=data_provider, batch_size=512,
    path_to_model='models/siamese')
"""

#Prediction
data_provider = tools.DataProvider(
    path_to_csv='dataset/little.csv',
    path_to_w2v='embeddings/GoogleNews-vectors-negative300.bin',
    # path_to_w2v='~/GoogleNews-vectors-negative300.bin',
    test_size=0,
    path_to_vocab='dataset/vocab.pickle')
Example #7
def fine_tune_encoder(encoder, train_dataloader, val_dataloader, seed, task_save_folder, test_dataloader=None,
                      metric_name='cpearsonr',
                      normalize_flag=False, **kwargs):
    target_decoder = MoMLP(input_dim=kwargs['latent_dim'],
                           output_dim=kwargs['output_dim'],
                           hidden_dims=kwargs['regressor_hidden_dims'],
                           out_fn=torch.nn.Sigmoid).to(kwargs['device'])

    target_regressor = EncoderDecoder(encoder=encoder,
                                      decoder=target_decoder,
                                      normalize_flag=normalize_flag).to(kwargs['device'])

    target_regression_train_history = defaultdict(list)
    target_regression_eval_train_history = defaultdict(list)
    target_regression_eval_val_history = defaultdict(list)
    target_regression_eval_test_history = defaultdict(list)

    encoder_module_indices = [i for i in range(len(list(encoder.modules())))
                              if str(list(encoder.modules())[i]).startswith('Linear')]

    reset_count = 1
    lr = kwargs['lr']

    target_regression_params = [target_regressor.decoder.parameters()]
    target_regression_optimizer = torch.optim.AdamW(chain(*target_regression_params),
                                                    lr=lr)

    for epoch in range(kwargs['train_num_epochs']):
        if epoch % 50 == 0:
            print(f'Fine tuning epoch {epoch}')
        for step, batch in enumerate(train_dataloader):
            target_regression_train_history = regression_train_step(model=target_regressor,
                                                                    batch=batch,
                                                                    device=kwargs['device'],
                                                                    optimizer=target_regression_optimizer,
                                                                    history=target_regression_train_history)
        target_regression_eval_train_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                                dataloader=train_dataloader,
                                                                                device=kwargs['device'],
                                                                                history=target_regression_eval_train_history)
        target_regression_eval_val_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                              dataloader=val_dataloader,
                                                                              device=kwargs['device'],
                                                                              history=target_regression_eval_val_history)

        if test_dataloader is not None:
            target_regression_eval_test_history = evaluate_target_regression_epoch(regressor=target_regressor,
                                                                                   dataloader=test_dataloader,
                                                                                   device=kwargs['device'],
                                                                                   history=target_regression_eval_test_history)
        save_flag, stop_flag = model_save_check(history=target_regression_eval_val_history,
                                                metric_name=metric_name,
                                                tolerance_count=10,
                                                reset_count=reset_count)
        if save_flag:
            torch.save(target_regressor.state_dict(),
                       os.path.join(task_save_folder, f'target_regressor_{seed}.pt'))

            torch.save(target_regressor.encoder.state_dict(),
                       os.path.join(task_save_folder, f'ft_encoder_{seed}.pt'))

        if stop_flag:
            try:
                ind = encoder_module_indices.pop()
                print(f'Unfreezing {epoch}')
                target_regressor.load_state_dict(
                    torch.load(os.path.join(task_save_folder, f'target_regressor_{seed}.pt')))

                target_regression_params.append(list(target_regressor.encoder.modules())[ind].parameters())
                lr = lr * kwargs['decay_coefficient']
                target_regression_optimizer = torch.optim.AdamW(chain(*target_regression_params), lr=lr)
                reset_count += 1
            except IndexError:
                break

    target_regressor.load_state_dict(
        torch.load(os.path.join(task_save_folder, f'target_regressor_{seed}.pt')))

    evaluate_target_regression_epoch(regressor=target_regressor,
                                     dataloader=val_dataloader,
                                     device=kwargs['device'],
                                     history=None,
                                     seed=seed,
                                     cv_flag=True,
                                     output_folder=kwargs['model_save_folder'])
    evaluate_target_regression_epoch(regressor=target_regressor,
                                     dataloader=test_dataloader,
                                     device=kwargs['device'],
                                     history=None,
                                     seed=seed,
                                     output_folder=kwargs['model_save_folder'])

    return target_regressor, (target_regression_train_history, target_regression_eval_train_history,
                              target_regression_eval_val_history, target_regression_eval_test_history)
Example #8
                               name='sampling_prob')

# define CNN feature-extractor
cfe = CryptoFeatureExtractorController(input_length=INPUT_SEQ_LEN,
                                       input_labels=train.input_labels,
                                       kernel_sizes=KERNEL_SIZES,
                                       kernel_filters=KERNEL_FILTERS,
                                       output_size=CNN_OUTPUT_SIZE)

# build CNN
new_sequence, new_length = cfe.build(input_ph=input_sequence)

# define RNN encoder-decoder
encoder_decoder = EncoderDecoder(num_units=NUM_UNITS,
                                 num_layers=NUM_LAYERS,
                                 input_length=new_length,
                                 input_depth=CNN_OUTPUT_SIZE,
                                 target_length=TARGET_SEQ_LEN,
                                 target_depth=TARGET_DEPTH)

# build RNN
outputs = encoder_decoder.build(input_ph=new_sequence,
                                target_ph=target_sequence,
                                keep_prob=rnn_keep_prob,
                                sampling_prob=sampling_prob)

#=========================================================================
# TRAINING PARAMS
#=========================================================================

# saves hparam and model variables
model_saver = Saver()
Example #9
def train_cleitc(dataloader, seed, **kwargs):
    """

    :param s_dataloaders:
    :param t_dataloaders:
    :param kwargs:
    :return:
    """
    autoencoder = AE(input_dim=kwargs['input_dim'],
                     latent_dim=kwargs['latent_dim'],
                     hidden_dims=kwargs['encoder_hidden_dims'],
                     dop=kwargs['dop']).to(kwargs['device'])

    # get reference encoder
    aux_ae = deepcopy(autoencoder)

    aux_ae.encoder.load_state_dict(
        torch.load(os.path.join('./model_save', f'ft_encoder_{seed}.pt')))
    print('reference encoder loaded')
    reference_encoder = aux_ae.encoder

    # construct transmitter
    transmitter = MLP(input_dim=kwargs['latent_dim'],
                      output_dim=kwargs['latent_dim'],
                      hidden_dims=[kwargs['latent_dim']]).to(kwargs['device'])

    ae_eval_train_history = defaultdict(list)
    ae_eval_test_history = defaultdict(list)

    if kwargs['retrain_flag']:
        cleit_params = [autoencoder.parameters(), transmitter.parameters()]
        cleit_optimizer = torch.optim.AdamW(chain(*cleit_params),
                                            lr=kwargs['lr'])
        # start autoencoder pretraining
        for epoch in range(int(kwargs['train_num_epochs'])):
            if epoch % 50 == 0:
                print(f'----Autoencoder Training Epoch {epoch} ----')
            for step, batch in enumerate(dataloader):
                ae_eval_train_history = cleit_train_step(
                    ae=autoencoder,
                    reference_encoder=reference_encoder,
                    transmitter=transmitter,
                    batch=batch,
                    device=kwargs['device'],
                    optimizer=cleit_optimizer,
                    history=ae_eval_train_history)
        torch.save(autoencoder.state_dict(),
                   os.path.join(kwargs['model_save_folder'], 'cleit_ae.pt'))
        torch.save(transmitter.state_dict(),
                   os.path.join(kwargs['model_save_folder'], 'transmitter.pt'))
    else:
        try:
            autoencoder.load_state_dict(
                torch.load(
                    os.path.join(kwargs['model_save_folder'], 'cleit_ae.pt')))
            transmitter.load_state_dict(
                torch.load(
                    os.path.join(kwargs['model_save_folder'],
                                 'transmitter.pt')))
        except FileNotFoundError:
            raise Exception("No pre-trained encoder")

    encoder = EncoderDecoder(encoder=autoencoder.encoder,
                             decoder=transmitter).to(kwargs['device'])

    return encoder, (ae_eval_train_history, ae_eval_test_history)
Example #10
                    "--vocab",
                    type=str,
                    help="The file with the vocab characters.")
args = parser.parse_args()
input_dir = args.data
vocab_file = args.vocab
files = [
    os.path.join(input_dir, f) for f in os.listdir(input_dir)
    if f.endswith('.csv')
]

TH = TweetHandler(files, vocab_file)
TH.set_train_split()
TH.remove_urls()

if not os.path.isdir(output_dir):
    os.mkdir(output_dir)

save_params = (os.path.join(output_dir,
                            model_name), os.path.join(output_dir, log_name))

enc = EncoderDecoder(hidden_dim, TH, num_lstms)
enc.do_training(seq_len,
                batch_size,
                num_epochs,
                learning_rate,
                samples_per_epoch,
                teacher_force_frac,
                slice_incr_frequency=slice_incr_frequency,
                save_params=save_params)
Example #11
def main(_):
    # loading
    print("Loading data...")
    data_list = data_helpers.load_data_and_labels(FLAGS.data_file)

    print("Loading w2v")
    sentence, sentence_len = [], []
    w2v = KeyedVectors.load_word2vec_format(FLAGS.word2vec_model, binary=False)
    vocab = w2v.vocab
    embeddings = np.zeros((len(w2v.index2word), w2v.vector_size), dtype=np.float32)

    print("convert sentence to index")
    for k, v in vocab.items():
        embeddings[v.index] = w2v[k]

    max_len = -1
    for item in data_list:
        sentence_index = [w2v.vocab[word].index if word in w2v.vocab else w2v.vocab["__UNK__"].index
                          for word in item.split(" ")]
        sentence.append(sentence_index)
        length = len(sentence_index)
        sentence_len.append(length)
        if length > max_len:
            max_len = length
    # pad the sentences, otherwise the data cannot be fed into the graph
    for item in sentence:
        item.extend([0] * (max_len - len(item)))
    print("Vocabulary Size: {:d}".format(len(w2v.vocab)))

    # save path
    timestamp = str(int(time.time()))
    out_dir = os.path.abspath(os.path.join(os.path.curdir, "runs/enc_dec/", timestamp))
    print("Writing to {}\n".format(out_dir))

    # checkpoint
    checkpoint_dir = os.path.abspath(os.path.join(out_dir, "checkpoints"))
    checkpoint_prefix = os.path.join(checkpoint_dir, "model")
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # build graph
    with tf.Graph().as_default():
        sess = tf.InteractiveSession()

        enc_dec = EncoderDecoder(
            embeddings=embeddings,
            encoder_hidden_size=64,
            decoder_hidden_size=64, )
        # train op
        global_step = tf.Variable(0, name="global_step", trainable=False)
        optimizer = tf.train.AdamOptimizer(1e-3)
        grads_and_vars = optimizer.compute_gradients(enc_dec.loss)
        train_op = optimizer.apply_gradients(grads_and_vars, global_step=global_step)

        # Summaries
        loss_summary = tf.summary.scalar("loss", enc_dec.loss)
        acc_summary = tf.summary.scalar("accuracy", enc_dec.accuracy)

        summary_op = tf.summary.merge([loss_summary, acc_summary])
        summary_dir = os.path.join(out_dir, "summaries")
        summary_writer = tf.summary.FileWriter(summary_dir, sess.graph)

        # saver
        saver = tf.train.Saver(tf.global_variables(), max_to_keep=FLAGS.num_checkpoints, save_relative_paths=True)

        # init
        sess.run(tf.global_variables_initializer())

        def train_step(sentence, sentence_len):
            """
            A single training step
            """
            feed_dict = {
                enc_dec.encoder_inputs: sentence,
                enc_dec.decoder_inputs: sentence,
                enc_dec.sequence_length: sentence_len
            }
            _, step, summaries, loss, acc = sess.run(
                [train_op, global_step, summary_op, enc_dec.loss, enc_dec.accuracy],
                feed_dict)
            time_str = datetime.datetime.now().isoformat()
            print("{}: step {}, loss {:g}, acc {:g}".format(time_str, step, loss, acc))
            summary_writer.add_summary(summaries, step)

        # do train
        batches = data_helpers.batch_iter(list(zip(sentence, sentence_len)), FLAGS.batch_size, FLAGS.num_epochs)
        for batch in batches:
            train_sentence, train_sentence_len = zip(*batch)
            train_step(list(train_sentence), list(train_sentence_len))
            current_step = tf.train.global_step(sess, global_step)
            if current_step % FLAGS.checkpoint_every == 0:
                path = saver.save(sess, checkpoint_prefix, global_step=current_step)
                print("Saved model checkpoint to {}\n".format(path))