예제 #1
0
def test_essay_batcher_2():
    """Smoke-test ``EssayBatcher`` on one essay file from the local ETS data.

    A ``TextParser`` is built either at character level (vocabularies from
    ``vocab_file``) or at word level (embeddings loaded from ``emb_path``),
    depending on ``use_chars``. Shuffled rows from the essay file are routed
    through a ``FieldParser`` (column 0 -> 'id', column 1 -> 'y', last column
    -> text). For each batch, the shapes of ``w`` and ``y`` are printed.
    Depends on hard-coded paths under /home/david.
    """
    use_chars = True  # flip to False to exercise the word-embedding branch

    U.seed_random(1234)
    keep_unk = False

    emb_dim = 100
    emb_path = '/home/david/data/embed/glove.6B.{}d.txt'
    vocab_file = os.path.join('/home/david/data/ets1b/2016', 'vocab_n250.txt')

    data_dir = '/home/david/data/ats/ets'
    essay_id = 55433
    essay_file = os.path.join(data_dir, '{0}', 'text.txt').format(essay_id)
    # alternative input: essay_id = 63986 with '{0}/{0}.txt.clean.tok'

    reader = GlobReader(essay_file, chunk_size=1000, regex=REGEX_NUM, shuf=True)

    if not use_chars:
        # Only the vocabulary is needed here; the embedding matrix is discarded.
        _, word_vocab = Vocab.load_word_embeddings_ORIG(emb_path, emb_dim, essay_file, min_freq=5)
        text_parser = TextParser(word_vocab=word_vocab, keep_unk=keep_unk)
    else:
        word_vocab, char_vocab, _ = Vocab.load_vocab(vocab_file)
        text_parser = TextParser(word_vocab=word_vocab, char_vocab=char_vocab, keep_unk=keep_unk)

    field_parser = FieldParser({0: 'id', 1: 'y', -1: text_parser}, reader=reader, seed=1234)

    batcher = EssayBatcher(reader=field_parser, batch_size=32, trim_words=True)
    for batch in batcher.batch_stream(stop=True, split_sentences=True):
        print('{}\t{}'.format(batch.w.shape, batch.y.shape))
예제 #2
0
def main(config, dataset_root, resume):
    """Train StarGAN from scratch, or resume after the last saved epoch.

    Args:
        config: value handed to ``get_config`` to build the configuration.
        dataset_root: root directory handed to ``get_data_loader``.
        resume: when truthy, call ``model.resume_train()``; otherwise start
            training at epoch 0 via ``model.train_starGAN``.
    """
    set_print_precision()
    seed_random()

    cfg = get_config(config)
    # Train loader is created first, then the test loader (same order as before).
    train_loader, test_loader = (
        get_data_loader(cfg, dataset_root, is_train=flag) for flag in (True, False)
    )

    model = StarGAN(cfg, train_loader, test_loader)

    if resume:
        # continue after the last saved epoch model
        model.resume_train()
        return
    model.train_starGAN(init_epoch=0)
예제 #3
0
파일: test.py 프로젝트: acan007/sinGAN
def main(config, mode):
    """Build the SinGAN variant selected by ``mode`` and save its test samples.

    Args:
        config: value handed to ``get_config`` to build the model configuration.
        mode: one of ``'random_sample'``, ``'paint2image'``, ``'editing'``,
            ``'harmonization'`` or ``'SR'`` (case-sensitive).

    Raises:
        ValueError: if ``mode`` is not a supported value. The check runs before
            any seeding or configuration loading, and the message names the
            rejected mode.
    """
    # Lambdas defer the class-name lookup, so (as with the original if/elif
    # chain) only the selected model class has to be resolvable.
    model_types = {
        'random_sample': lambda: SinGAN,
        'paint2image': lambda: Paint2Image,
        'editing': lambda: Editing,
        'harmonization': lambda: Harmonization,
        'SR': lambda: SuperResolution,
    }
    if mode not in model_types:
        raise ValueError(
            'Invalid Parameter: unknown mode {!r}; expected one of {}'.format(
                mode, ', '.join(sorted(model_types))))

    seed_random()
    set_numpy_precision()

    model_type = model_types[mode]()
    model = model_type(get_config(config))
    model.test_samples(save=True)
예제 #4
0
def test_response_batcher():
    """Smoke-test ``ResponseBatcher`` on one response file and report batch volumes.

    Builds a word-level (or, with ``char = True``, character-level)
    ``TextParser`` and routes the unshuffled rows through a ``FieldParser``
    (column 0 -> 'id', column 1 -> 'y', last column -> text). Batches are
    generated with post-padding for sentences and words. For each batch it
    prints the element count of ``b.x`` together with the ``x``/``y`` shapes,
    then prints the running total. Depends on hard-coded paths under
    /home/david.
    """
    char = False  # set True to exercise the character-level branch

    U.seed_random(1234)
    keep_unk = False

    emb_dim = 100
    emb_path = '/home/david/data/embed/glove.6B.{}d.txt'

    vocab_dir = '/home/david/data/ets1b/2016'
    vocab_file = os.path.join(vocab_dir, 'vocab_n250.txt')

    data_dir = '/home/david/data/ats/ets'
    essay_id = 56375  # avoid shadowing the builtin ``id``
    essay_file = os.path.join(data_dir, '{0}/text.txt').format(essay_id)

    reader = GlobReader(essay_file, chunk_size=1000, regex=REGEX_NUM, shuf=False)

    if char:
        word_vocab, char_vocab, _ = Vocab.load_vocab(vocab_file)
        text_parser = TextParser(word_vocab=word_vocab, char_vocab=char_vocab, keep_unk=keep_unk)
    else:
        # Only the vocabulary is used; the embedding matrix is discarded.
        _, word_vocab = Vocab.load_word_embeddings_ORIG(emb_path, emb_dim, essay_file, min_freq=5)
        text_parser = TextParser(word_vocab=word_vocab, keep_unk=keep_unk)

    fields = {0: 'id', 1: 'y', -1: text_parser}
    field_parser = FieldParser(fields, reader=reader, seed=1234)

    batcher = ResponseBatcher(reader=field_parser, batch_size=64, trim_words=True)
    tot_vol = 0
    for b in batcher.batch_stream(stop=True, split_sentences=True, spad='post', wpad='post'):
        vol = np.prod(b.x.shape)  # number of elements in this batch's x array
        tot_vol += vol
        print('{}\t{}\t{}'.format(vol, b.x.shape, b.y.shape))
    print('tot_vol:\t{}'.format(tot_vol))
예제 #5
0
def parse_config(config_file, parser):
    """Load options from ``config_file`` with ``parser`` and derive dependent settings.

    Post-processing applied to the object returned by ``get_config``:

    * ``chkpt_dir`` is made absolute. When ``load_model`` is set,
      ``load_chkpt_dir`` is made absolute, or defaults to ``chkpt_dir``.
    * ``cwd``, ``log_file`` and ``rand_seed`` are set; ``rand_seed`` becomes
      the value returned by ``U.seed_random``.
    * ``id_dir`` is resolved relative to ``data_dir``.
    * Attention/pooling flags are adjusted.
    * Test/valid/train id files are built from the ``*_pat`` patterns and
      their ids are loaded. In the patterns, ``{0}`` is ``item_id`` and
      ``{1}`` is ``'_<trait>'``, or ``''`` when no trait is set.
    * The ``embed`` descriptor and the word/char embedding dirs are built.
    * List-valued string options are evaluated, and ``wpad``, ``spad`` and
      ``attn_vis`` are derived from them.

    Returns:
        The updated FLAGS object.
    """
    #parser = options.get_parser()
    argv = []  # override config file here
    FLAGS = get_config(parser=parser, config_file=config_file, argv=argv)
    FLAGS.chkpt_dir = U.make_abs(FLAGS.chkpt_dir)

    if FLAGS.load_model:
        if FLAGS.load_chkpt_dir:
            FLAGS.load_chkpt_dir = U.make_abs(FLAGS.load_chkpt_dir)
        else:
            FLAGS.load_chkpt_dir = FLAGS.chkpt_dir
    elif FLAGS.model == 'HANModel':
        FLAGS.epoch_unfreeze_word = 0

    FLAGS.cwd = os.getcwd()
    FLAGS.log_file = os.path.abspath(os.path.join(FLAGS.cwd, 'log.txt'))

    FLAGS.rand_seed = U.seed_random(FLAGS.rand_seed)

    if FLAGS.id_dir is None:
        FLAGS.id_dir = FLAGS.data_dir
    else:
        FLAGS.id_dir = os.path.join(FLAGS.data_dir, FLAGS.id_dir).format(FLAGS.item_id)

    if FLAGS.attn_size > 0:
        FLAGS.mean_pool = False
        if FLAGS.attn_type < 0:
            FLAGS.attn_type = 0

    if FLAGS.embed_type == 'word':
        FLAGS.model_std = None
        FLAGS.attn_std = None

    #### test ids
    test_ids, test_id_file = None, None
    FLAGS.test_y, FLAGS.test_yint = None, None

    if FLAGS.test_pat is None:
        FLAGS.save_test = None
        FLAGS.load_test = None
    else:
        test_id_file = _id_file_path(FLAGS, FLAGS.test_pat)
        if FLAGS.load_test and U.check_file(test_id_file):
            data = U.read_cols(test_id_file)
            test_ids = data[:, 0]
            # Float-typed ids (e.g. 123.0) are cast to int first so the
            # string form has no decimal part.
            if test_ids.dtype.name.startswith('float'):
                test_ids = test_ids.astype('int32')
            test_ids = test_ids.astype('unicode')

            if data.shape[1] > 1:
                # NOTE(review): a file with exactly two columns raises
                # IndexError on column 2 — confirm labelled id files always
                # carry three columns.
                FLAGS.test_yint = data[:, 1].astype('int32')
                FLAGS.test_y = data[:, 2].astype('float32')
        # Ids were loaded from an existing file, so don't write them back.
        if FLAGS.save_test and test_ids is not None:
            FLAGS.save_test = False
    FLAGS.test_ids = test_ids if test_ids is not None else []
    FLAGS.test_id_file = test_id_file

    # Don't overwrite MLT test ids!!!
    # test_id_file is None whenever test_pat is unset; guard before the
    # substring test, which would otherwise raise TypeError on None.
    if test_id_file is not None and 'test_ids' in test_id_file:
        FLAGS.save_test = False

    #### valid ids
    valid_ids, valid_id_file = None, None
    if FLAGS.valid_pat is None:
        FLAGS.save_valid = None
        FLAGS.load_valid = None
    else:
        valid_id_file = _id_file_path(FLAGS, FLAGS.valid_pat)
        if FLAGS.load_valid:
            valid_ids = get_ids(valid_id_file)
        if FLAGS.save_valid and valid_ids is not None:
            FLAGS.save_valid = False
    FLAGS.valid_ids = valid_ids if valid_ids is not None else []
    FLAGS.valid_id_file = valid_id_file

    #### train ids
    train_ids, train_id_file = None, None
    if FLAGS.train_pat:
        train_id_file = _id_file_path(FLAGS, FLAGS.train_pat)
        train_ids = get_ids(train_id_file, default=[])
    FLAGS.train_ids = train_ids if train_ids is not None else []
    FLAGS.train_id_file = train_id_file

    ###################################
    FLAGS.embed = U.adict({'type': FLAGS.embed_type,
                           'char': FLAGS.embed_type == 'char',
                           'word': FLAGS.embed_type == 'word'})

    FLAGS.word_embed_dir = os.path.join(FLAGS.embed_dir, 'word')
    FLAGS.char_embed_dir = os.path.join(FLAGS.embed_dir, 'char')

    feats = ['kernel_widths', 'kernel_features', 'rnn_cells', 'rnn_sizes', 'rnn_bis',
             'attn_sizes', 'attn_depths', 'attn_temps', 'pads', 'learning_rates']
    for feat in feats:
        if feat in FLAGS and FLAGS[feat]:
            # SECURITY: double eval executes arbitrary expressions taken from
            # the config file — only use with trusted configs.
            FLAGS[feat] = eval(eval(FLAGS[feat]))
    FLAGS.wpad = FLAGS.pads[0]
    FLAGS.spad = None if len(FLAGS.pads) < 2 else FLAGS.pads[1]

    if FLAGS.attn_depths[0] > 1 or (len(FLAGS.attn_depths) > 1 and FLAGS.attn_depths[1] > 1):
        FLAGS.attn_vis = False

    if FLAGS.attn_sizes[0] < 1:
        FLAGS.attn_vis = False

    if FLAGS.embed.char:
        FLAGS.attn_vis = False

    ###################################
    return FLAGS


def _trait_suffix(FLAGS):
    """Return ``'_<trait>'`` when ``FLAGS.trait`` is set, otherwise ``''``."""
    return '' if FLAGS.trait is None else '_{}'.format(FLAGS.trait)


def _id_file_path(FLAGS, pat):
    """Join ``pat`` onto ``FLAGS.data_dir`` and fill ``{0}`` = item_id, ``{1}`` = trait suffix."""
    return os.path.join(FLAGS.data_dir, pat).format(FLAGS.item_id, _trait_suffix(FLAGS))
예제 #6
0
def run():
    """Run the co-evolution experiments and export one plot per enabled figure.

    Each experiment does the following, in order:

    1. Reseed via ``seed_random(<seed>, use_fixed_seed)``.
    2. Build two populations of ``n`` individuals.
    3. Construct a ``Coevolution`` executor and run it for the given number
       of generations.
    4. Save the plot into ``export_folder``.

    The ``fig_*`` / ``extension`` switches select which experiments run.

    Each run that uses a hall of fame gets its own ``HOF`` instance. The
    previous version shared one instance across figures 2–5 (when
    ``enable_hof``) and across the extension configs. Those experiments are
    seeded identically so they can be compared, and they must not see each
    other's hall of fame if ``HOF`` retains members between runs.
    NOTE(review): ``HOF`` internals are not visible here.
    """
    plotter = Plotter(show_avgs=True, plot_elites=False)
    fig_1 = True
    fig_2 = True
    fig_3 = True
    fig_4 = True
    fig_5 = True
    extension = True

    use_fixed_seed = 1
    enable_hof = False
    enable_virulence = False
    generations = 600

    export_folder = "../report/plots/"
    generator_a = Generator(100, 1)
    generator_b = Generator(10, 10)
    generator_c = Generator(50, 2)
    mutator = Mutator(mutation_rate=0.005, bit_flip=False)
    selector = FitnessProportionateSelection()
    n = 25

    # Wrap selector with virulence handling
    if enable_virulence:
        selector = VirulenceSelector(selector, 0.75, normalise=True)

    def fresh_hof(enabled):
        """Return a new hall of fame when ``enabled``, else None."""
        return HOF(scorer=Scorer(sample_size=10), size=50) if enabled else None

    def run_figure(title, seed_value, generator, make_executor, export_name,
                   num_generations, pop_b_arg=0):
        """Seed, build both populations, run one executor and export its plot.

        ``make_executor`` is called after seeding and population creation,
        the same point at which each original per-figure block built its
        executor. ``pop_b_arg`` is the second argument passed to
        ``generator.population`` for population B (figure 1 uses 100).
        """
        seed = seed_random(seed_value, use_fixed_seed)
        description = "{} : [Seed {}]".format(title, seed)
        print(description)
        pop_a = generator.population(n, 0)
        pop_b = generator.population(n, pop_b_arg)
        executor = make_executor()
        executor.run(pop_a, pop_b, num_generations)
        plotter.make_plot(
            executor,
            fig_name=description,
            export_path=os.path.join(export_folder, export_name),
        )

    if fig_1:
        run_figure(
            "Figure 1", 0, generator_a,
            lambda: Coevolution(scorer=F0Scorer(), selector=selector),
            "fig1.png", generations, pop_b_arg=100,
        )

    if fig_2:
        run_figure(
            "Figure 2", 1, generator_a,
            lambda: Coevolution(mutator=mutator, hof=fresh_hof(enable_hof), selector=selector),
            "fig2.png", generations,
        )

    if fig_3:
        run_figure(
            "Figure 3", 8486058433753192762, generator_a,
            lambda: Coevolution(
                mutator=mutator,
                scorer=Scorer(sample_size=1),
                hof=fresh_hof(enable_hof),
                selector=selector,
            ),
            "fig3.png", generations,
        )

    if fig_4:
        run_figure(
            "Figure 4", 59759543964706904, generator_b,
            lambda: Coevolution(mutator=mutator, hof=fresh_hof(enable_hof), selector=selector),
            "fig4.png", generations,
        )

    if fig_5:
        run_figure(
            "Figure 5", 5706501168717675099, generator_c,
            lambda: Coevolution(
                mutator=mutator,
                scorer=Scorer(intransitive=True),
                hof=fresh_hof(enable_hof),
                selector=selector,
            ),
            "fig5.png", generations,
        )

    if extension:
        generations = 1200
        fp_selector = FitnessProportionateSelection()
        sus_selector = StochasticUniversalSampling()
        t_selector = TournamentSelection(3)
        # (label, output file, selector, use a hall of fame?)
        ext_cfgs = [
            ("Virulence [0.5]", "fig_5_v0.5.png", VirulenceSelector(fp_selector, 0.5), False),
            ("Virulence [0.75]", "fig_5_v0.75.png", VirulenceSelector(fp_selector, 0.75), False),
            ("Virulence [0.75] + SUS", "fig_5_v0.75_sus.png",
             VirulenceSelector(sus_selector, 0.75), False),
            ("Virulence [0.75] + TS", "fig_5_v0.75_ts.png",
             VirulenceSelector(t_selector, 0.75), False),
            ("Virulence [0.75] + SUS + HOF", "fig_5_v0.75_sus_hof.png",
             VirulenceSelector(sus_selector, 0.75), True),
            ("Virulence [0.75] + HOF", "fig_5_v0.75_hof.png",
             VirulenceSelector(fp_selector, 0.75), True),
            ("HOF", "fig_5_hof.png", fp_selector, True),
        ]
        for cfg_txt, export_name, ext_selector, use_hof in ext_cfgs:
            # The lambda is invoked inside run_figure during this same
            # iteration, so it sees the current loop values.
            run_figure(
                "Figure 5 - {}".format(cfg_txt), 8985012493578745191, generator_c,
                lambda: Coevolution(
                    mutator=mutator,
                    scorer=Scorer(intransitive=True),
                    selector=ext_selector,
                    hof=fresh_hof(use_hof),
                ),
                export_name, generations,
            )