def main(args):

    new_model = args.new_model
    N = int(args.N)
    epochs = int(args.epochs)

    vae = VAE()

    if not new_model:
        try:
            vae.set_weights('./vae/weights.h5')
        except Exception:
            print("Either set --new_model or ensure ./vae/weights.h5 exists")
            raise

    try:
        data, N = import_data(N)
    except Exception:
        print('NO DATA FOUND')
        raise

    print('DATA SHAPE = {}'.format(data.shape))

    for epoch in range(epochs):
        print('EPOCH ' + str(epoch))
        vae.train(data)
        vae.save_weights('./vae/weights.h5')
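
# A minimal sketch (assumption, not part of the original example) of the argparse
# entry point this main(args) expects; the flag names are taken from how `args`
# is used above, and the defaults are only illustrative.
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Train the VAE')
    parser.add_argument('--N', default=10000, help='number of observations to use')
    parser.add_argument('--new_model', action='store_true',
                        help='start from freshly initialised weights')
    parser.add_argument('--epochs', default=10, help='number of training epochs')
    main(parser.parse_args())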
def main(args):
    exec_time = datetime.now().strftime('%Y%m%d-%H%M%S')
    tensorboard = TensorBoard(log_dir=f'log/vae/{exec_time}',
                              update_freq='batch')

    new_model = args.new_model
    epochs = int(args.epochs)
    steps = int(args.steps)

    # instantiate VAE
    vae = VAE() if new_model else load_vae(WEIGHT_FILE_NAME)

    # get training set and validation set generators
    t_gen, v_gen = get_generators(INPUT_DIR_NAME, TV_RATIO)

    # start training!
    vae.train(t_gen,
              v_gen,
              epochs=epochs,
              steps_per_epoch=steps,
              validation_steps=int(steps * TV_RATIO),
              workers=10,
              callbacks=[tensorboard])

    # save model weights
    vae.save_weights(WEIGHT_FILE_NAME)
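
# A minimal sketch (assumption) of what get_generators could look like, using
# Keras' ImageDataGenerator with TV_RATIO as the validation split; it assumes
# INPUT_DIR_NAME holds the frames in at least one subdirectory, and the
# project's real helper may differ.
from tensorflow.keras.preprocessing.image import ImageDataGenerator

def get_generators(input_dir, tv_ratio, target_size=(64, 64), batch_size=32):
    datagen = ImageDataGenerator(rescale=1. / 255, validation_split=tv_ratio)
    # class_mode='input' yields (image, image) pairs, as an autoencoder needs
    t_gen = datagen.flow_from_directory(input_dir, target_size=target_size,
                                        batch_size=batch_size,
                                        class_mode='input', subset='training')
    v_gen = datagen.flow_from_directory(input_dir, target_size=target_size,
                                        batch_size=batch_size,
                                        class_mode='input', subset='validation')
    return t_gen, v_gen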
Example #3
def main(args):

    new_model = args.new_model
    S = int(args.S)
    N = int(args.N)
    batch = int(args.batch)
    model_name = str(args.model_name)
    alpha = float(args.alpha)
    print('alpha = {}'.format(alpha))
    vae = VAE()

    if not new_model:
        try:
            vae.set_weights('./vae/' + model_name + '/' + model_name +
                            '_weights.h5')

        except Exception:
            print("Either set --new_model or ensure ./vae/" + model_name +
                  "/" + model_name + "_weights.h5 exists")
            raise
    else:
        if os.path.isdir('./vae/' + model_name):
            print("A model with this name already exists")
        else:
            os.mkdir('./vae/' + model_name)
            os.mkdir('./vae/' + model_name + '/log/')

    filelist = os.listdir(DIR_NAME)
    filelist = [x for x in filelist if x != '.DS_Store' and x != '.gitignore']
    filelist.sort()

    for i in range(round(float(N - S) / batch)):
        data = import_data(S + i * batch, S + (i + 1) * batch, filelist)
        dataS = []
        dataB = []
        for d in data:
            beta = alpha + np.random.rand() * (1 - alpha)
            dataS.append(
                cv2.resize(crop(d, alpha * beta),
                           dsize=(SCREEN_SIZE_X, SCREEN_SIZE_Y),
                           interpolation=cv2.INTER_CUBIC))
            dataB.append(
                cv2.resize(crop(d, beta),
                           dsize=(SCREEN_SIZE_X, SCREEN_SIZE_Y),
                           interpolation=cv2.INTER_CUBIC))

        dataS = np.asarray(dataS)
        dataB = np.asarray(dataB)

        # (2) train the augmenting VAE (small crop -> large crop targets); to train
        # the simple VAE / RNN variant (1) instead, comment this call out and
        # uncomment the line below
        vae.train(dataS, dataB, model_name)
        #vae.train(np.vstack([dataS, dataB]), np.vstack([dataS, dataB]), model_name)

        vae.save_weights('./vae/' + model_name + '/' + model_name +
                         '_weights.h5')

        print('Imported {} / {}'.format(S + (i + 1) * batch, N))
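
# A minimal sketch (assumption) of the crop() helper used above: keep a centred
# window covering `fraction` of each side, which the caller then resizes back to
# SCREEN_SIZE_X x SCREEN_SIZE_Y. The project's actual crop() may differ.
def crop(img, fraction):
    h, w = img.shape[:2]
    new_h, new_w = int(h * fraction), int(w * fraction)
    top, left = (h - new_h) // 2, (w - new_w) // 2
    return img[top:top + new_h, left:left + new_w]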
def main(args):

    num_files = args.num_files
    load_model = args.load_model

    vae = VAE()

    if load_model != "None":
        try:
            print("Loading model " + load_model)
            vae.set_weights(load_model)
        except Exception:
            print("Either don't set --load_model or ensure " + load_model + " exists")
            raise

    vae.train(num_files)
Example #5
def main(args):

    start_batch = args.start_batch
    max_batch = args.max_batch
    new_model = args.new_model
    # epochs = args.epochs

    vae = VAE()

    if not new_model:
        try:
            vae.set_weights('./vae/weights.h5')
        except Exception:
            print("Either set --new_model or ensure ./vae/weights.h5 exists")
            raise

    for batch_num in range(start_batch, max_batch + 1):
        print('Building batch {}...'.format(batch_num))
        data = np.load('./data/obs_data_' + str(batch_num) + '.npy')
        data = np.array([item for obs in data for item in obs])
        vae.train(data)
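
# Note (assumption, not shown in this snippet): each ./data/obs_data_<batch>.npy
# is taken to hold an array of episodes, each episode being a sequence of
# observation frames; the comprehension above flattens the episodes into a single
# array of frames before it is handed to vae.train().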
Example #6
def train_on_drives(args):
    vae = None

    for data, y, cat in loadDataBatches(args.dirs,
                                        skip_actions=True,
                                        max_batch=10000):
        if vae is None:
            input_dim = data[0].shape
            print("Data shape: {}".format(data.shape))
            vae = VAE(input_dim=input_dim)

            if not args.new_model:
                try:
                    vae.set_weights('./vae/weights.h5')
                except Exception:
                    print(
                        "Either set --new_model or ensure ./vae/weights.h5 exists"
                    )
                    raise

        vae.train(data, epochs=100)
Example #7
def main(args):

    start_batch = args.start_batch
    max_batch = args.max_batch
    new_model = args.new_model
    epochs = args.epochs

    vae = VAE()

    if not new_model:
        try:
            vae.set_weights('./vae/weights.h5')
        except Exception:
            print("Either set --new_model or ensure ./vae/weights.h5 exists")
            raise

    for i in range(epochs):
        for batch_num in range(start_batch, max_batch + 1):
            print('Building batch {}...'.format(batch_num))
            first_item = True

            for env_name in config.train_envs:
                try:
                    new_data = np.load('./data/obs_data_' + env_name + '_' +
                                       str(batch_num) + '.npy')
                    if first_item:
                        data = new_data
                        first_item = False
                    else:
                        data = np.concatenate([data, new_data])
                    print('Found {}...current data size = {} episodes'.format(
                        env_name, len(data)))
                except FileNotFoundError:
                    # no data for this environment at this batch number
                    pass

            if not first_item:  # i.e. data has been found for this batch number
                data = np.array([item for obs in data for item in obs])
                vae.train(data)
            else:
                print('no data found for batch number {}'.format(batch_num))
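
# Note (assumption): config.train_envs is taken to be a list of environment
# names, e.g. ['car_racing'], so the inner loop above concatenates the
# observations recorded for each environment at the same batch number.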
def main(args):

    new_model = args.new_model
    S = int(args.S)
    N = int(args.N)
    batch = int(args.batch)
    epochs = int(args.epochs)
    model_name = str(args.model_name)

    vae = VAE()

    if not new_model:
        try:
            vae.set_weights('./vae/' + model_name + '/' + model_name +
                            '_weights.h5')
        except Exception:
            print("Either set --new_model or ensure ./vae/" + model_name + "/" +
                  model_name + "_weights.h5 exists")
            raise
    elif not os.path.isdir('./vae/' + model_name):
        os.mkdir('./vae/' + model_name)
        os.mkdir('./vae/' + model_name + '/log/')

    filelist = os.listdir(DIR_NAME + model_name)
    filelist = [x for x in filelist if x != '.DS_Store' and x != '.gitignore']
    filelist.sort()
    N = max(N, len(filelist))

    for i in range(int(round(float(N - S) / batch) + 1)):
        dataS, dataB = import_data(S + i * batch, S + (i + 1) * batch,
                                   filelist, model_name)
        for epoch in range(epochs):
            # (2) train the augmenting VAE (small crop -> large crop targets); to train
            # the simple VAE / RNN variant (1) instead, comment this call out and
            # uncomment the line below
            vae.train(dataS, dataB, model_name)
            #vae.train(np.vstack([dataS, dataB]), np.vstack([dataS, dataB]), model_name)
        vae.save_weights('./vae/' + model_name + '/' + model_name +
                         '_weights.h5')

        print('Imported {} / {}'.format(S + (i + 1) * batch, N))
def main(args):

  start_batch = args.start_batch
  max_batch = args.max_batch
  new_model = args.new_model

  vae = VAE()

  

  if not new_model:
    try:
      vae.set_weights('./vae/weights.h5')
    except Exception:
      print("Either set --new_model or ensure ./vae/weights.h5 exists")
      raise

  for batch_num in range(start_batch, max_batch + 1):
    print('Building batch {}...'.format(batch_num))
    first_item = True

    for env_name in config.train_envs:
      try:
        new_data = np.load('./data/obs_data_' + env_name + '_' + str(batch_num) + '.npy')
        if first_item:
          data = new_data
          first_item = False
        else:
          data = np.concatenate([data, new_data])
        print('Found {}...current data size = {} episodes'.format(env_name, len(data)))
      except FileNotFoundError:
        # no data for this environment at this batch number
        pass

    if not first_item:  # i.e. data has been found for this batch number
      data = np.array([item for obs in data for item in obs])
      vae.train(data)
    else:
      print('no data found for batch number {}'.format(batch_num))
Example #10
# instantiate the VAE and inspect its architecture
vae = VAE()
vae.model.summary()

# train the VAE on MNIST digits
(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

x_train = x_train.reshape(x_train.shape + (1, ))
x_test = x_test.reshape(x_test.shape + (1, ))

#x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
#x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))

vae.train(x_train)

# build a model to project inputs on the latent space
# (x and z_mean are the encoder input tensor and latent mean, assumed to be
# defined where the VAE's layers were built)
encoder = Model(x, z_mean)

# display a 2D plot of the digit classes in the latent space
x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
plt.figure(figsize=(6, 6))
plt.scatter(x_test_encoded[:, 0], x_test_encoded[:, 1], c=y_test)
plt.colorbar()
plt.show()

# build a digit generator that can sample from the learned distribution
decoder_input = Input(shape=(latent_dim, ))
_h_decoded = decoder_h(decoder_input)
_x_decoded_mean = decoder_mean(_h_decoded)
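
# Continuation of the standard Keras VAE example (a sketch; it assumes
# latent_dim == 2, that numpy is imported as np, and adds
# `from scipy.stats import norm`): wrap the decoding layers into a generator
# model and render digits sampled from a grid over the 2-D latent space.
from scipy.stats import norm

generator = Model(decoder_input, _x_decoded_mean)

n = 15            # 15 x 15 grid of digits
digit_size = 28
figure = np.zeros((digit_size * n, digit_size * n))
# map a regular grid on [0.05, 0.95] through the inverse Gaussian CDF so the
# samples cover the latent prior
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

for i, yi in enumerate(grid_y):
    for j, xi in enumerate(grid_x):
        z_sample = np.array([[xi, yi]])
        x_decoded = generator.predict(z_sample)
        digit = x_decoded[0].reshape(digit_size, digit_size)
        figure[i * digit_size:(i + 1) * digit_size,
               j * digit_size:(j + 1) * digit_size] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()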