Пример #1
0
 def test_save_and_load(self):
     """Round-trip a trained Corgan model through save_obj/load_obj.

     Trains a Corgan model on a small random 0/1 matrix and pickles the
     returned model with ``save_obj``. It then reloads the model with
     ``load_obj``, generates ``n_gen`` synthetic rows from the reloaded
     model, and checks the row count.

     Every artifact (the training checkpoint(s) and the pickled model) is
     written inside a temporary directory. That directory is removed even
     when a step fails, so a broken run never leaves files in the working
     directory.
     """
     import tempfile

     prefix_checkpoint = 'test'
     n_epochs_pretrain = 10
     n_epochs = 10
     n_gen = 500
     n = 1000
     m = 7

     cor = Corgan()

     # dummy dataset: n rows x m columns of random 0/1 values
     x = np.random.randint(low=0, high=2, size=(n, m))

     with tempfile.TemporaryDirectory() as path_checkpoint:
         model_saved = cor.train(x=x,
                                 n_epochs_pretrain=n_epochs_pretrain,
                                 n_epochs=n_epochs,
                                 path_checkpoint=path_checkpoint,
                                 prefix_checkpoint=prefix_checkpoint)

         # The test expects train() to leave a checkpoint for the final
         # epoch at this path; build it from the same n_epochs passed above
         # so the two cannot drift apart.
         file_ckpt = os.path.join(
             path_checkpoint,
             prefix_checkpoint + ".model_epoch_%d.pth" % n_epochs)
         assert os.path.isfile(file_ckpt)

         file_model = os.path.join(path_checkpoint, 'test.pkl')
         cor.save_obj(obj=model_saved, file_name=file_model)
         model_loaded = cor.load_obj(file_model)
         x_synth = cor.generate(model=model_loaded, n_gen=n_gen)

     assert len(x_synth) == n_gen
Пример #2
0
    if args.train_type == 'corgan':
        syn = Corgan(debug=debug, n_cpu=args.n_cpu_train)
    elif args.train_type == 'ppgan':
        syn = Ppgan(debug=debug, n_cpu=args.n_cpu_train)
    model = syn.train(x=r_trn, n_epochs=args.n_epoch)
    model['m'] = meta
    model['header'] = obj_d['header']
    syn.save_obj(model, outfile)

elif args.task == 'generate':

    pre = Preprocessor(missing_value=args.missing_value_generate)
    outfile = args.outprefix_generate + '.csv'

    syn = Corgan()
    model = syn.load_obj(args.file_model)
    if model['parameter_dict']['model'] == 'corgan':
        syn = Corgan()
    elif model['parameter_dict']['model'] == 'ppgan':
        syn = Ppgan(debug=False, n_cpu=1)

    s = syn.generate(model, n_gen=args.generate_size)

    f = pre.restore_matrix(arr=s, meta=model['m'], header=model['header'])
    np.savetxt(fname=outfile,
               fmt='%s',
               X=f['x'],
               delimiter=',',
               header=','.join(f['header']),
               comments='')