Example #1
def train_discriminator_wrapper(x, x_gen, batch_size=1, vocab_size=10):
    # Label real samples as 1 and generated samples as 0
    y = gen_label(len(x), fixed_value=1)
    y_gen = gen_label(len(x_gen), fixed_value=0)
    # Stack real and generated samples into a single training set
    x_train = torch.cat([x.int(), x_gen.int()], dim=0)
    y_train = torch.cat([y, y_gen], dim=0)
    discriminator = train_discriminator(x_train, y_train, batch_size, vocab_size)
    return discriminator
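# A minimal usage sketch for the wrapper above. It assumes `gen_label` and
# `train_discriminator` are importable from the surrounding module and that the
# inputs are batches of integer token ids; the shapes and vocabulary size below
# are illustrative only.
import torch

# Hypothetical batches: 4 real and 4 generated sequences of length 8,
# with token ids drawn from a vocabulary of 10 symbols.
x_real = torch.randint(0, 10, (4, 8))
x_fake = torch.randint(0, 10, (4, 8))

discriminator = train_discriminator_wrapper(x_real, x_fake, batch_size=2, vocab_size=10)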
Example #2
def sanityCheck_rewards(batch_size=5):
    ''' test rewards generation '''
    log = openLog('test.txt')
    log.write('\n\nTest rollout.sanityCheck_rewards: {}'.format(
        datetime.now()))
    try:
        generator, _, y_output_all = sanityCheck_generator(
            batch_size=batch_size, sample_size=batch_size * 2)
        gen_output = y_output_all[-batch_size:, :]
        rollout = Rollout(generator=generator)
        rollout = nn.DataParallel(rollout)
        rollout.to(DEVICE)
        discriminator = train_discriminator(
            batch_size=batch_size,
            vocab_size=generator.pretrain_model.module.vocab_size)
        rewards = getReward(gen_output, rollout, discriminator)
        log.write('\n  rollout.sanityCheck_rewards SUCCESSFUL. {}\n'.format(
            datetime.now()))
        log.close()
        return rewards
    except Exception:
        log.write(
            '\n  rollout.sanityCheck_rewards !!!!!! UNSUCCESSFUL !!!!!! {}\n'.
            format(datetime.now()))
        log.close()
        return None
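# Minimal invocation sketch for the sanity check above (assumes DEVICE, openLog,
# sanityCheck_generator, Rollout, train_discriminator and getReward are defined
# in the surrounding project):
rewards = sanityCheck_rewards(batch_size=4)
print(rewards)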
Example #3
def train(num_generations, generation_size, starting_selfies, max_molecules_len,
          disc_epochs_per_generation, disc_enc_type, disc_layers, training_start_gen,
          device, properties_calc_ls, num_processors, beta,
          max_fitness_collector, impose_time_adapted_pen):

    # Collect all generated molecules and count occurrences
    smiles_all         = []
    selfies_all        = []
    smiles_all_counter = {}
    
    # Prepare the generator and discriminator
    initial_population = du.sanitize_multiple_smiles([decoder(selfie) for selfie in starting_selfies]) 
    reference_mols = dict.fromkeys(du.read_dataset_encoding(disc_enc_type), '')
    discriminator, d_opt, d_loss = D.build_discriminator(disc_enc_type, disc_layers, max_molecules_len, device)

    total_time = time.time()
    for generation_index in range(1, num_generations+1):
        print(f"###Generation {generation_index} of {num_generations}")
        start_time = time.time()
              
        # Molecules of the previous population
        smiles_here, selfies_here = G.get_prev_gen_mols(initial_population, starting_selfies, generation_size,
                                                        generation_index, selfies_all, smiles_all)

        # Fitness of the previous population
        fitness_here, order, fitness_ordered, smiles_ordered, selfies_ordered = G.get_fitness(
            disc_enc_type, smiles_here, selfies_here,
            properties_calc_ls, discriminator, generation_index,
            max_molecules_len, device, generation_size,
            num_processors, writer, beta,
            image_dir, data_dir, max_fitness_collector, impose_time_adapted_pen)

        # Molecules to replace and to keep
        to_replace, to_keep = G.cutoff_population(order, generation_size)
        
        # Molecules of the current population
        smiles_mutated, selfies_mutated = G.get_next_gen_mols(order, to_replace, to_keep,
                                                              selfies_ordered, smiles_ordered, max_molecules_len)
        # Results of generator
        smiles_all, selfies_all, smiles_all_counter = G.update_gen_res(smiles_all, smiles_mutated, selfies_all, selfies_mutated, smiles_all_counter)

        # Data for discriminator training
        dataset_x, dataset_y = G.get_dis_data(disc_enc_type, reference_mols, smiles_mutated, selfies_mutated,
                                              max_molecules_len, num_processors, generation_index)
        
        if generation_index >= training_start_gen:
            discriminator = D.train_discriminator(dataset_x, dataset_y, discriminator, d_opt, d_loss,
                                                  disc_epochs_per_generation, generation_index-1, device, writer, data_dir)
            D.save_model(discriminator, generation_index-1, saved_models_dir)

        print('Generation time: ', round((time.time()-start_time), 2), ' seconds')

    print('Experiment time: ', round((time.time()-total_time)/60, 2), ' mins')
    print('Number of unique molecules: ', len(smiles_all_counter))
    return smiles_all_counter
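# A hypothetical invocation sketch for the training loop above. The argument
# values and property names are illustrative assumptions, not settings from the
# original experiment, and `train` additionally expects module-level names such
# as `writer`, `image_dir`, `data_dir` and `saved_models_dir` to be defined.
smiles_all_counter = train(
    num_generations=100,
    generation_size=500,
    starting_selfies=['[C][C][O]'],        # hypothetical seed SELFIES string
    max_molecules_len=81,
    disc_epochs_per_generation=10,
    disc_enc_type='smiles',                # assumed encoding choice
    disc_layers=[100, 10],
    training_start_gen=0,
    device='cpu',
    properties_calc_ls=['logP', 'SAS', 'RingP'],   # assumed property list
    num_processors=1,
    beta=0.0,
    max_fitness_collector=[],
    impose_time_adapted_pen=False)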
Example #4
def train_adversarial(sess, saver, MODEL_STRING, generator, discriminator, 
                      rollout, dis_data_loader, likelihood_data_loader, 
                      task, log, n):
    print('#################################################################')
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    saver.restore(sess, tf.train.latest_checkpoint(MODEL_STRING))
    small_loss = float('inf')
    for total_batch in range(n):
        # Train the generator for one step
        samples = generator.generate(sess)
        rewards = rollout.get_reward(sess, samples, 16, discriminator)  # 16 = rollout count used for reward estimation; may need tuning
        feed = {generator.x: samples, generator.rewards: rewards}
        _ = sess.run(generator.g_updates, feed_dict=feed)

        # Test: write fresh samples to the eval file and print a few decoded examples
        generate_samples(sess, generator, BATCH_SIZE,
                         task.generated_num, task.eval_file)
        print("Examples from generator:")
        for sample in task.vocab.decode(samples)[:5]:
            print(sample)

        likelihood_data_loader.create_batches(task.valid_file)
        test_loss = target_loss(sess, generator, likelihood_data_loader)
        if test_loss < small_loss:
            small_loss = test_loss
            saver.save(sess, MODEL_STRING +"/model")
            print("Saving checkpoint ...")
        print("total_batch: ", total_batch, "test_loss: ", test_loss)
        buffer = "total_batch: " + str(total_batch) + "test_loss: " + str(test_loss)
        log.write(buffer)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator for 5 steps
        train_discriminator(sess, generator, discriminator, dis_data_loader, 
                            task, log, 5, BATCH_SIZE, task.generated_num,
                            dis_dropout_keep_prob)
Example #5
def main():
    
    #  Create a parser to parse user input
    def parse_arguments():
        parser = argparse.ArgumentParser(description='Program for running several SeqGan applications.')
        parser.add_argument('app', metavar='application', type=str, choices=['obama', 'haiku', 'synth'],
                            help='Enter \'obama\', \'haiku\', or \'synth\'')
        parser.add_argument('gen_n', type=int,
                            help='Number of generator pre-training steps')
        parser.add_argument('disc_n', type=int,
                            help='Number of discriminator pre-training steps')
        parser.add_argument('adv_n', type=int,
                            help='Number of adversarial training steps')
        parser.add_argument('-mn', metavar="model_name", type=str, default="",
                            help="Name for the checkpoint files. Will be stored at ./<app>/models/<model_name>")
        parser.add_argument('-numeat', metavar="num_eat", type=int, default=500,
                            help="For synthetic data generation. Determines number of eaters in vocab.")
        parser.add_argument('-numfeed', metavar="num_feed", type=int, default=500,
                            help="For synthetic data generation. Determines number of feeders in vocab.")
        parser.add_argument('-numsent', metavar="num_sent", type=int, default=10000,
                            help="For synthetic data generation. Determines number of sentences generated.")
        args = parser.parse_args()

        synth_gen_params = ("NA", "NA", "NA")
        if args.app == "synth":
            synth_gen_params = (args.numsent, args.numfeed, args.numeat)
            generate_random_sents("../data/synth/input.txt", args.numsent, args.numfeed, args.numeat)

        task = load_task(args.app)

        # Make the /models directory if it's not there.
        model_string = task.path +"/models/"
        if not os.path.exists("./"+model_string):
            os.mkdir("./"+model_string)
    
        # Make the checkpoint directory if it's not there.
        if args.mn == "":
            model_string += str(args.gen_n)+ "_" + str(args.disc_n) + "_" + str(args.adv_n)
            model_string += time.strftime("_on_%m_%d_%y", time.gmtime())
        else:
            model_string += args.mn
        if not os.path.exists("./"+model_string):
            os.mkdir("./"+model_string)
    
        return args.gen_n, args.disc_n, args.adv_n, model_string, task, synth_gen_params
    
    gen_n, disc_n, adv_n, MODEL_STRING, task, SYNTH_GEN_PARAMS = parse_arguments()


    assert START_TOKEN == 0

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

    # Initialize the data loaders
    gen_data_loader = Gen_Data_loader(BATCH_SIZE, task.max_seq_length)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE, task.max_seq_length) # For validation
    dis_data_loader = Dis_dataloader(BATCH_SIZE, task.max_seq_length)

    # Initialize the Generator
    generator = Generator(len(task.vocab), BATCH_SIZE, EMB_DIM, HIDDEN_DIM, 
                          task.max_seq_length, START_TOKEN)

    # Initialize the Discriminator
    discriminator = Discriminator(sequence_length=task.max_seq_length, 
                                  num_classes=2, 
                                  vocab_size=len(task.vocab), 
                                  embedding_size=dis_embedding_dim, 
                                  filter_sizes=dis_filter_sizes, 
                                  num_filters=dis_num_filters, 
                                  l2_reg_lambda=dis_l2_reg_lambda)

    # Set session configurations. 
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    saver = tf.train.Saver()
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())

    # If restoring from a previous run ....
    if len(os.listdir("./"+MODEL_STRING)) > 0:
        saver.restore(sess, tf.train.latest_checkpoint(MODEL_STRING))


    # Create batches from the positive file.
    gen_data_loader.create_batches(task.train_file)

    # Open log file for writing
    log = open(task.log_file, 'w')

    # Pre-train the generator with MLE.
    pre_train_generator(sess, saver, MODEL_STRING, generator, gen_data_loader, 
                        likelihood_data_loader, task, log, gen_n, BATCH_SIZE,
                        task.generated_num)
    print('Start pre-training discriminator...')

    # Do the discriminator pre-training steps
    saver.restore(sess, tf.train.latest_checkpoint(MODEL_STRING))
    train_discriminator(sess, generator, discriminator, dis_data_loader, 
                        task, log, disc_n, BATCH_SIZE, task.generated_num,
                        dis_dropout_keep_prob)
    print("Saving checkpoint ...")
    saver.save(sess, MODEL_STRING+ "/model")
    
    # Do the adversarial training steps
    rollout = ROLLOUT(generator, 0.8)
    train_adversarial(sess, saver, MODEL_STRING, generator, discriminator, 
                      rollout, dis_data_loader, likelihood_data_loader, 
                      task, log, adv_n)

    # Use the best model to generate the final samples
    saver.restore(sess, tf.train.latest_checkpoint(MODEL_STRING))
    generate_samples(sess, generator, BATCH_SIZE, task.generated_num, 
                     task.eval_file)


    # Read back the generated samples for evaluation
    with open(task.eval_file) as f:
        generated = []
        for line in f:
            line = line.strip().split()
            generated.append(line)
        generated = task.vocab.decode(generated)

    with open(task.test_file) as f:
        references = []
        for line in f:
            line = line.strip().split()
            references.append(line)
        references = task.vocab.decode(references)

    bleu = corpus_bleu([references]*len(generated), generated)
    print("Run with args {} {} {}: BLEU score = {}\n".format(gen_n, disc_n, adv_n, bleu))
    
    prop = "NA"

    if task.name == "synth":
        total_correct = 0
        for sentence in generated:
            if is_valid_phrase(sentence):
                total_correct += 1
        prop = total_correct/len(generated)
        
    if not os.path.exists("./results.csv"):
        os.mknod("./results.csv")

    with open("./results.csv", 'a') as csvfile:
        fieldnames = ["name", "task_name", "num_gen", "num_disc", "num_adv",
                    "num_sents", "num_feeders", "num_eaters", "BLEU", "prop_valid"]
        writer = csv.DictWriter(csvfile, fieldnames = fieldnames)
        writer.writeheader()
        writer.writerow({"name": MODEL_STRING, "task_name": task.name,  "num_gen": gen_n, 
                        "num_disc":disc_n, "num_adv": adv_n, "num_sents":SYNTH_GEN_PARAMS[0],
                        "num_feeders":SYNTH_GEN_PARAMS[1], "num_eaters":SYNTH_GEN_PARAMS[2],
                        "BLEU": blue, "prop_valid": prop})
        f.close()


    log.close()
Example #6
from graph_freezer import freezing_graph
from discriminator import train_discriminator
from estimator import train_estimator

# Train both models, then freeze their computation graphs for export
discriminator = train_discriminator()
estimator = train_estimator()

freezing_graph("discriminator")
freezing_graph("estimator")