Ejemplo n.º 1
0
def train(args):
    """Fit a GAN (``dcgan`` or ``wdcgan``) on the dataset selected by ``args``.

    ``args`` supplies the dataset and model names, the optimiser settings
    ``lr``, ``c`` and ``n_critic``, and the training settings ``epochs``,
    ``n_batch`` and ``logdir``.

    Raises:
        ValueError: if the dataset or model name is not recognised.
    """
    import models
    import numpy as np  # only needed if the seeding line below is re-enabled
    # np.random.seed(1234)

    dataset = args.dataset
    if dataset == 'random':
        # Noise data: the training set is reused as the validation set.
        n_dim, n_out, n_channels = 2, 2, 1
        X_train, y_train = data.load_noise(n=1000, d=n_dim)
        X_val, y_val = X_train, y_train
    elif dataset == 'mnist':
        n_dim, n_out, n_channels = 28, 10, 1
        X_train, y_train, X_val, y_val, _, _ = data.load_mnist()
    else:
        raise ValueError('Invalid dataset name: %s' % dataset)

    optimiser = dict(lr=args.lr, c=args.c, n_critic=args.n_critic)

    # Both GAN variants share one constructor signature; only the class differs.
    if args.model == 'dcgan':
        gan_class = models.DCGAN
    elif args.model == 'wdcgan':
        gan_class = models.WDCGAN
    else:
        raise ValueError('Invalid model')
    model = gan_class(n_dim=n_dim, n_chan=n_channels, opt_params=optimiser)

    # Only the images are handed to fit; the labels are not used here.
    model.fit(X_train, X_val,
              n_epoch=args.epochs,
              n_batch=args.n_batch,
              logdir=args.logdir)
Ejemplo n.º 2
0
def train(args):
    """Load the dataset named by ``args.dataset`` and fit a DCGAN/WDCGAN on it.

    ``args`` supplies the dataset and model names, ``same_train_data`` (read
    only for ``malware_clean_data``), the optimiser settings ``lr``, ``c`` and
    ``n_critic``, and the training settings ``epochs``, ``n_batch`` and
    ``logdir``.

    Raises:
        ValueError: for an unknown dataset or model, or when
            ``malware_clean_data`` is requested without ``same_train_data``
            (no train/validation split is defined for that case).
    """
    # load dataset
    if args.dataset == 'mnist':
        n_dim, n_out, n_channels = 28, 10, 1
        X_train, y_train, X_val, y_val, _, _ = data.load_mnist()

    elif args.dataset == 'random':
        # Noise data: the training set is reused as the validation set.
        n_dim, n_out, n_channels = 2, 2, 1
        X_train, y_train = data.load_noise(n=1000, d=n_dim)
        X_val, y_val = X_train, y_train

    # Extensible: further datasets can be added as extra branches here.
    elif args.dataset == 'malware_clean_data':
        if not args.same_train_data:
            # Without this check the train/validation names stayed unbound and
            # model.fit() crashed with UnboundLocalError; fail fast instead.
            raise ValueError(
                'malware_clean_data is only supported with same_train_data set')
        n_dim, n_channels = 64, 1
        # The loader returns malware/benign train splits followed by the
        # malware/benign test splits; the test splits are unused here.
        # Malware training images are fit, benign training images validate.
        X_train, y_train, X_val, y_val, _, _, _, _ = (
            data.load_Malware_clean_ApkToImage())
    else:
        raise ValueError('Invalid dataset name: %s' % args.dataset)

    # set up optimization params
    opt_params = {'lr': args.lr, 'c': args.c, 'n_critic': args.n_critic}

    # create model
    if args.model == 'dcgan':
        model = models.DCGAN(n_dim=n_dim,
                             n_chan=n_channels,
                             opt_params=opt_params)
    elif args.model == 'wdcgan':
        model = models.WDCGAN(n_dim=n_dim,
                              n_chan=n_channels,
                              opt_params=opt_params)
    else:
        raise ValueError('Invalid model')

    # train model
    model.fit(X_train,
              y_train,
              X_val,
              y_val,
              n_epoch=args.epochs,
              n_batch=args.n_batch,
              logdir=args.logdir)
Ejemplo n.º 3
0
# Maps each ``args.model`` value accepted by ``train`` to the name of the class
# in the ``models`` module implementing it. The class is looked up with getattr
# only after the name is validated, so just the selected class is accessed.
_MODEL_CLASS_NAMES = {
    'vae': 'VAE',
    'discrete-vae': 'SBN',
    'discrete-vae-rbm': 'USBN',
    'adgm': 'ADGM',
    'discrete-adgm': 'DADGM',
    'discrete-adgm-rbm': 'UDADGM',
    'rbm': 'RBM',
    'vrbm': 'VariationalRBM',
    'avrbm': 'AuxiliaryVariationalRBM',
}


def _plot_hallucinated_samples(samples, plotname):
    """Draw ``samples`` as one grayscale image titled 'Hallucinated Samples' and save it to ``plotname``."""
    plt.figure(figsize=(5, 5))
    plt.imshow(samples, cmap=plt.cm.gray, interpolation='none')
    plt.title('Hallucinated Samples')
    plt.tight_layout()
    plt.savefig(plotname)


def train(args):
    """Build a generative model, train it (or load it) and plot its samples.

    ``args`` supplies the dataset and model names, the optimiser settings
    (lr, b1, b2, n_batch, alg, n_superbatch), ``epochs`` and ``logname`` for
    training, ``plotname`` for the output figure, and an optional ``pkl``.
    When ``pkl`` is set, the saved parameters are loaded, a sample chain is
    plotted and the process exits with status 0 without training.

    Raises:
        ValueError: if the dataset or model name is not recognised.
    """
    import sys

    import models
    import numpy as np
    np.random.seed(1234)  # fixed seed for repeatable runs

    dataset = args.dataset
    if dataset == 'mnist':
        n_dim, n_out, n_channels = 28, 10, 1
        X_train, Y_train, X_val, Y_val, _, _ = data.load_mnist()
    elif dataset in ('binmnist', 'omniglot'):
        n_dim, n_out, n_channels = 28, 10, 1
        if dataset == 'binmnist':
            X_train, X_val, _ = data.load_mnist_binarized()
        else:
            # The loader's label arrays are replaced by placeholders below.
            X_train, _, X_val, _ = data.load_omniglot_iwae()
        X_train = X_train.reshape((-1, 1, 28, 28))
        X_val = X_val.reshape((-1, 1, 28, 28))
        # Uninitialised int32 label placeholders, one per image; presumably
        # ignored by these unsupervised models -- TODO confirm.
        Y_train = np.empty((X_train.shape[0],), dtype='int32')
        Y_val = np.empty((X_val.shape[0],), dtype='int32')
    elif dataset == 'digits':
        n_dim, n_out, n_channels = 8, 10, 1
        X_train, Y_train, X_val, Y_val, _, _ = data.load_digits()
    else:
        # The former fallback loaded args.train/args.test with data.load_h5
        # but never set n_dim/n_out/n_channels. Because those names are
        # assigned in the branches above, they are locals, so that path always
        # died with UnboundLocalError when building the model. Fail clearly.
        raise ValueError('Invalid dataset name: %s' % dataset)

    print('dataset loaded.')

    # set up optimization params
    p = {'lr': args.lr, 'b1': args.b1, 'b2': args.b2, 'nb': args.n_batch}

    # create model: every supported class takes the same constructor arguments
    class_name = _MODEL_CLASS_NAMES.get(args.model)
    if class_name is None:
        raise ValueError('Invalid model')
    model = getattr(models, class_name)(
        n_dim=n_dim,
        n_out=n_out,
        n_chan=n_channels,
        n_superbatch=args.n_superbatch,
        opt_alg=args.alg,
        opt_params=p,
    )

    if args.pkl:
        # Sampling-only mode: restore saved parameters, plot a chain, stop.
        model.load(args.pkl)
        _plot_hallucinated_samples(model.hallucinate_chain(), args.plotname)
        # The interactive-only builtin exit('hello') printed a stray debug
        # string and returned status 1; exit cleanly with status 0 instead.
        sys.exit(0)

    # train model
    model.fit(X_train,
              Y_train,
              X_val,
              Y_val,
              n_epoch=args.epochs,
              n_batch=args.n_batch,
              logname=args.logname)

    # generate and plot samples from the trained model
    _plot_hallucinated_samples(model.hallucinate(), args.plotname)
Ejemplo n.º 4
0
def train(args):
  import models
  import numpy as np
  np.random.seed(1234)

  if args.dataset == 'digits':
    n_dim, n_out, n_channels = 8, 10, 1
    X_train, y_train, X_val, y_val = data.load_digits()
  elif args.dataset == 'mnist':
    n_dim, n_out, n_channels = 28, 10, 1
    X_train, y_train, X_val, y_val, _, _ = data.load_mnist()
  elif args.dataset == 'svhn':
    n_dim, n_out, n_channels = 32, 10, 3
    X_train, y_train, X_val, y_val = data.load_svhn()
    X_train, y_train, X_val, y_val = data.prepare_dataset(X_train, y_train, X_val, y_val)
  elif args.dataset == 'cifar10':
    n_dim, n_out, n_channels = 32, 10, 3
    X_train, y_train, X_val, y_val = data.load_cifar10()
    X_train, y_train, X_val, y_val = data.prepare_dataset(X_train, y_train, X_val, y_val)
  elif args.dataset == 'random':
    n_dim, n_out, n_channels = 2, 2, 1
    X_train, y_train = data.load_noise(n=1000, d=n_dim)
    X_val, y_val = X_train, y_train
  else:
    raise ValueError('Invalid dataset name: %s' % args.dataset)
  print 'dataset loaded, dim:', X_train.shape

  # set up optimization params
  p = { 'lr' : args.lr, 'b1': args.b1, 'b2': args.b2 }

  # create model
  if args.model == 'softmax':
    model = models.Softmax(n_dim=n_dim, n_out=n_out, n_superbatch=args.n_superbatch, 
                           opt_alg=args.alg, opt_params=p)
  elif args.model == 'mlp':
    model = models.MLP(n_dim=n_dim, n_out=n_out, n_superbatch=args.n_superbatch, 
                       opt_alg=args.alg, opt_params=p)
  elif args.model == 'cnn':
    model = models.CNN(n_dim=n_dim, n_out=n_out, n_chan=n_channels, model=args.dataset,
                       n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p)  
  elif args.model == 'kcnn':
    model = models.KCNN(n_dim=n_dim, n_out=n_out, n_chan=n_channels, model=args.dataset,
                       n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p)    
  elif args.model == 'resnet':
    model = models.Resnet(n_dim=n_dim, n_out=n_out, n_chan=n_channels,
                          n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p)    
  elif args.model == 'vae':
    model = models.VAE(n_dim=n_dim, n_out=n_out, n_chan=n_channels, n_batch=args.n_batch,
                          n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p,
                          model='bernoulli' if args.dataset in ('digits', 'mnist') 
                                            else 'gaussian')    
  elif args.model == 'convvae':
    model = models.ConvVAE(n_dim=n_dim, n_out=n_out, n_chan=n_channels, n_batch=args.n_batch,
                          n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p,
                          model='bernoulli' if args.dataset in ('digits', 'mnist') 
                                            else 'gaussian')    
  elif args.model == 'convadgm':
    model = models.ConvADGM(n_dim=n_dim, n_out=n_out, n_chan=n_channels, n_batch=args.n_batch,
                          n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p,
                          model='bernoulli' if args.dataset in ('digits', 'mnist') 
                                            else 'gaussian')    
  elif args.model == 'sbn':
    model = models.SBN(n_dim=n_dim, n_out=n_out, n_chan=n_channels,
                          n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p)      
  elif args.model == 'adgm':
    model = models.ADGM(n_dim=n_dim, n_out=n_out, n_chan=n_channels, n_batch=args.n_batch,
                          n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p,
                          model='bernoulli' if args.dataset in ('digits', 'mnist') 
                                            else 'gaussian')
  elif args.model == 'hdgm':
    model = models.HDGM(n_dim=n_dim, n_out=n_out, n_chan=n_channels, n_batch=args.n_batch,
                          n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p)        
  elif args.model == 'dadgm':
    model = models.DADGM(n_dim=n_dim, n_out=n_out, n_chan=n_channels,
                          n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p) 
  elif args.model == 'dcgan':
    model = models.DCGAN(n_dim=n_dim, n_out=n_out, n_chan=n_channels,
                          n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p)   
  elif args.model == 'ssadgm':
    X_train_lbl, y_train_lbl, X_train_unl, y_train_unl \
      = data.split_semisup(X_train, y_train, n_lbl=args.n_labeled)
    model = models.SSADGM(X_labeled=X_train_lbl, y_labeled=y_train_lbl, n_out=n_out,
                          n_superbatch=args.n_superbatch, opt_alg=args.alg, opt_params=p)
    X_train, y_train = X_train_unl, y_train_unl
  else:
    raise ValueError('Invalid model')
  
  # train model
  model.fit(X_train, y_train, X_val, y_val, 
            n_epoch=args.epochs, n_batch=args.n_batch,
            logname=args.logname)
Ejemplo n.º 5
0
            start_time = time.time()  # time evaluation

            # mini batch
            for i in range(0, (train_Size // BATCH_SIZE)):
                _, loss_ = sess.run(
                    fetches=(optimizer, loss),
                    feed_dict={X_p: X[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]})
                losses.append(loss_)

            print("loss:", sum(losses) / len(losses))

            #store the first pictre
            gen_pic = sess.run(fetches=generate, feed_dict={X_p: X})
            #print("gen_pic:\n",gen_pic)
            plt.imshow(np.reshape(X[3], newshape=[28, 28]))
            plt.imshow(np.reshape(gen_pic[3], newshape=[28, 28]))
            plt.show()


if __name__ == "__main__":
    # Load and standardise MNIST, print a quick summary, then train.
    print("Load Data....")
    X, y_train, X_test = data.load_mnist(reshape=False)
    X, X_test = data.standard(X_train=X, X_test=X_test)
    for loaded in (X, y_train, X_test):
        print(loaded.shape)
    print(X[0])

    print("Run Model....")
    train(X)
Ejemplo n.º 6
0
def train(args):
    """Train a supervised classifier on the dataset named by ``args.dataset``.

    Supported models are ``supervised-mlp`` (``models.SupervisedMLP``) and
    ``supervised-hdgm`` (``models.SupervisedHDGM``, which is fed images
    reshaped to ``(-1, n_channels, n_dim_x, n_dim_y)``). ``args`` also
    supplies the optimiser settings (lr, b1, b2, alg, n_superbatch), an
    optional checkpoint path ``reload``, ``epochs``/``n_batch`` for fit, and
    the parts of the log name (logname, n_labeled, sup_weight, unsup_weight).

    Raises:
        ValueError: if the dataset or model name is not recognised.
    """
    import models
    import numpy as np
    # Seeding is disabled; call np.random.seed(1234) here for repeatable runs.

    # For every dataset, n_dim == n_dim_x * n_dim_y * n_channels; n_aug is
    # passed through to SupervisedMLP unchanged.
    if args.dataset == 'digits':
        # n_dim/n_aug used to be left unset here, so supervised-mlp crashed
        # with UnboundLocalError; 8 * 8 * 1 = 64 follows the pattern below.
        n_dim, n_aug, n_dim_x, n_dim_y, n_out, n_channels = 64, 0, 8, 8, 10, 1
        X_train, y_train, X_val, y_val = data.load_digits()
    elif args.dataset == 'mnist':
        # Load supervised data: train on train + validation, and validate on
        # the held-out test split.
        n_dim, n_aug, n_dim_x, n_dim_y, n_out, n_channels = 784, 0, 28, 28, 10, 1
        X_train, y_train, X_val, y_val, X_test, y_test = data.load_mnist()
        X_train = np.concatenate((X_train, X_val))
        y_train = np.concatenate((y_train, y_val))
        X_val, y_val = X_test, y_test
    elif args.dataset == 'svhn':
        n_dim, n_aug, n_dim_x, n_dim_y, n_out, n_channels = 3072, 0, 32, 32, 10, 3
        X_train, y_train, X_val, y_val = data.load_svhn()
    elif args.dataset == 'random':
        # n_dim used to be read before assignment (UnboundLocalError) and no
        # validation split was set. 2-d noise with 2 classes; the training set
        # doubles as the validation set.
        # NOTE(review): the 2 x 1 spatial layout (n_dim_x, n_dim_y) is an
        # assumption that only matters for supervised-hdgm -- confirm it
        # against the shape returned by data.load_noise.
        n_dim, n_aug, n_dim_x, n_dim_y, n_out, n_channels = 2, 0, 2, 1, 2, 1
        X_train, y_train = data.load_noise(n=1000, d=n_dim)
        X_val, y_val = X_train, y_train
    else:
        raise ValueError('Invalid dataset name: %s' % args.dataset)
    print('dataset loaded.')

    # set up optimization params
    p = {'lr': args.lr, 'b1': args.b1, 'b2': args.b2}

    # create model
    if args.model == 'supervised-mlp':
        model = models.SupervisedMLP(n_out=n_out,
                                     n_dim=n_dim,
                                     n_aug=n_aug,
                                     n_dim_x=n_dim_x,
                                     n_dim_y=n_dim_y,
                                     n_chan=n_channels,
                                     n_superbatch=args.n_superbatch,
                                     opt_alg=args.alg,
                                     opt_params=p)
    elif args.model == 'supervised-hdgm':
        # This model consumes images, so restore the channel/spatial axes.
        X_train = X_train.reshape(-1, n_channels, n_dim_x, n_dim_y)
        X_val = X_val.reshape(-1, n_channels, n_dim_x, n_dim_y)
        model = models.SupervisedHDGM(
            n_out=n_out,
            n_dim=n_dim_x,
            n_chan=n_channels,
            n_superbatch=args.n_superbatch,
            opt_alg=args.alg,
            opt_params=p,
            model='bernoulli' if args.dataset == 'mnist' else 'gaussian')
    else:
        raise ValueError('Invalid model')

    if args.reload:
        model.load(args.reload)

    # e.g. "<logname>-<dataset>-<model>-<n_labeled>-<sup_weight>-<unsup_weight>"
    logname = '%s-%s-%s-%d-%f-%f' % (args.logname, args.dataset, args.model,
                                     args.n_labeled, args.sup_weight,
                                     args.unsup_weight)

    # train model
    model.fit(X_train,
              y_train,
              X_val,
              y_val,
              n_epoch=args.epochs,
              n_batch=args.n_batch,
              logname=logname)