Example #1
import os
import sys

import numpy as np
from keras.callbacks import EarlyStopping

# ConvVAE, DriveDataGenerator, and reset_graph are assumed to come from the
# surrounding project.


def main(dirs, z_size=32, batch_size=100, learning_rate=0.0001,
         kl_tolerance=0.5, epochs=100, save_model=False, verbose=True,
         optimizer="Adam"):

    if save_model:
        model_save_path = "tf_vae"
        if not os.path.exists(model_save_path):
            os.makedirs(model_save_path)

    gen = DriveDataGenerator(dirs, image_size=(64, 64), batch_size=batch_size,
                             shuffle=True, max_load=10000, images_only=True)
        
    num_batches = len(gen)

    reset_graph()

    vae = ConvVAE(z_size=z_size,
                  batch_size=batch_size,
                  learning_rate=learning_rate,
                  kl_tolerance=kl_tolerance,
                  is_training=True,
                  reuse=False,
                  gpu_mode=True,
                  optimizer=optimizer)

    early = EarlyStopping(monitor='loss', min_delta=0.1, patience=5, verbose=verbose, mode='auto')
    early.set_model(vae)
    early.on_train_begin()

    best_loss = sys.maxsize

    if verbose:
        print("epoch\tstep\tloss\trecon_loss\tkl_loss")
    for epoch in range(epochs):
        for idx in range(num_batches):
            batch = gen[idx]

            # np.float was removed in NumPy 1.24; use an explicit dtype
            obs = batch.astype(np.float32) / 255.0

            feed = {vae.x: obs}

            (train_loss, r_loss, kl_loss, train_step, _) = vae.sess.run([
              vae.loss, vae.r_loss, vae.kl_loss, vae.global_step, vae.train_op
            ], feed)
            
            if train_loss < best_loss:
                best_loss = train_loss

            # checkpoint every 5000 global steps
            if save_model and ((train_step + 1) % 5000 == 0):
                vae.save_json(os.path.join(model_save_path, "vae.json"))
        if verbose:
            print("{} of {}\t{}\t{:.2f}\t{:.2f}\t{:.2f}".format( epoch, epochs, (train_step+1), train_loss, r_loss, kl_loss) )
        gen.on_epoch_end()
        early.on_epoch_end(epoch, logs={"loss": train_loss})
        # ConvVAE is not a Keras model; EarlyStopping only sets stop_training
        # once it triggers, so guard the attribute access.
        if getattr(vae, "stop_training", False):
            break
    early.on_train_end()


    # finished; save the final model
    if save_model:
        vae.save_json(os.path.join(model_save_path, "vae.json"))

    return best_loss
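
The examples on this page all share one pattern: Keras callbacks driven by hand around a custom training loop, via set_model(), on_train_begin(), on_epoch_end(logs), and on_train_end(). A minimal self-contained sketch of that pattern (hypothetical loss numbers; any object with a stop_training attribute can stand in for the model):

from keras.callbacks import EarlyStopping

class DummyModel:
    stop_training = False  # EarlyStopping sets this once patience runs out

model = DummyModel()
early = EarlyStopping(monitor="loss", min_delta=0.1, patience=2)
early.set_model(model)
early.on_train_begin()

fake_losses = [10.0, 9.0, 8.95, 8.94, 8.93]  # improvement stalls after epoch 1
for epoch, loss in enumerate(fake_losses):
    early.on_epoch_end(epoch, logs={"loss": loss})
    if model.stop_training:
        print("early stop at epoch", epoch)
        break
early.on_train_end()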
Example #2
    # cat holds the categorical input indices, num the numeric block;
    # the env1/env2 inputs are left commented out
    X_train = [np.absolute(X_train[i]) for i in cat] + [X_train[num]]
    # X_train += [X_train[env1]] + [X_train[env2]]
    X_val = [np.absolute(X_val[i]) for i in cat] + [X_val[num]]
    # X_val += [X_val[env1]] + [X_val[env2]]

    model = model_NN()
    model.compile(optimizer='Adam',
                  loss='categorical_crossentropy',
                  metrics=[])
    # val_CRPS is not a built-in metric: the project-defined Metric callback
    # computes it each epoch and forwards it to this EarlyStopping instance
    es = EarlyStopping(monitor='val_CRPS',
                       mode='min',
                       restore_best_weights=True,
                       verbose=2,
                       patience=5)
    es.set_model(model)
    metric = Metric(model, [es], [(X_train, y_train), (X_val, y_val)])
    # warm-up passes, one epoch each, at increasing batch sizes
    model.fit(X_train, y_train, verbose=False)
    model.fit(X_train, y_train, batch_size=64, verbose=False)
    model.fit(X_train, y_train, batch_size=128, verbose=False)
    model.fit(X_train, y_train, batch_size=256, verbose=False)
    model.fit(X_train,
              y_train,
              callbacks=[metric],
              epochs=100,
              batch_size=1024,
              verbose=False)
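
Example #2 monitors val_CRPS, which no built-in Keras metric produces: the Metric callback has to compute it on the validation data at the end of each epoch, write it into the logs dict under that key, and forward the call to the wrapped EarlyStopping. The real Metric class is defined elsewhere; this is only a plausible sketch, and the CRPS formula over cumulative distributions is an assumption:

import numpy as np
from keras.callbacks import Callback

class MetricSketch(Callback):
    # hypothetical stand-in for Metric(model, [es], [(X_train, y_train), (X_val, y_val)])
    def __init__(self, model, callbacks, data):
        super().__init__()
        self.set_model(model)
        self.callbacks = callbacks  # e.g. the EarlyStopping above
        self.data = data

    def on_train_begin(self, logs=None):
        # let the wrapped callbacks reset their state
        for cb in self.callbacks:
            cb.on_train_begin(logs)

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        X_val, y_val = self.data[1]
        y_pred = self.model.predict(X_val, verbose=0)
        # CRPS between predicted and true cumulative distributions (assumed)
        logs['val_CRPS'] = float(np.mean(
            (np.cumsum(y_pred, axis=1) - np.cumsum(y_val, axis=1)) ** 2))
        # forward the enriched logs so EarlyStopping sees val_CRPS
        for cb in self.callbacks:
            cb.on_epoch_end(epoch, logs)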
Example #3
from keras.callbacks import (Callback, CSVLogger, EarlyStopping, History,
                             ModelCheckpoint, TensorBoard)

# full_path, last_logs, and average_logs are helpers assumed to come from the
# surrounding project.


def train_model(model, data, config, include_tensorboard):
	model_history = History()
	model_history.on_train_begin()
	saver = ModelCheckpoint(full_path(config.model_file()), verbose=1, save_best_only=True, period=1)
	saver.set_model(model)
	early_stopping = EarlyStopping(min_delta=config.min_delta, patience=config.patience, verbose=1)
	early_stopping.set_model(model)
	early_stopping.on_train_begin()
	csv_logger = CSVLogger(full_path(config.csv_log_file()))
	csv_logger.on_train_begin()
	if include_tensorboard:
		tensorboard = TensorBoard(histogram_freq=10, write_images=True)
		tensorboard.set_model(model)
	else:
		tensorboard = Callback()  # no-op stand-in so the shared calls below stay safe

	epoch = 0
	stop = False
	while epoch <= config.max_epochs and not stop:
		epoch_history = History()
		epoch_history.on_train_begin()
		valid_sizes = []
		train_sizes = []
		print("Epoch:", epoch)
		for dataset in data.datasets:
			print("dataset:", dataset.name)
			model.reset_states()
			dataset.reset_generators()

			valid_sizes.append(dataset.valid_generators[0].size())
			train_sizes.append(dataset.train_generators[0].size())
			# Keras 1.x generator API (positional samples_per_epoch, nb_epoch, nb_val_samples)
			fit_history = model.fit_generator(dataset.train_generators[0],
				dataset.train_generators[0].size(),
				nb_epoch=1,
				verbose=0,
				validation_data=dataset.valid_generators[0],
				nb_val_samples=dataset.valid_generators[0].size())

			epoch_history.on_epoch_end(epoch, last_logs(fit_history))

			train_sizes.append(dataset.train_generators[1].size())
			fit_history = model.fit_generator(dataset.train_generators[1],
				dataset.train_generators[1].size(),
				nb_epoch=1, 
				verbose=0)

			epoch_history.on_epoch_end(epoch, last_logs(fit_history))

		epoch_logs = average_logs(epoch_history, train_sizes, valid_sizes)
		model_history.on_epoch_end(epoch, logs=epoch_logs)
		saver.on_epoch_end(epoch, logs=epoch_logs)
		early_stopping.on_epoch_end(epoch, epoch_logs)
		csv_logger.on_epoch_end(epoch, epoch_logs)
		tensorboard.on_epoch_end(epoch, epoch_logs)
		epoch += 1

		if early_stopping.stopped_epoch > 0:
			stop = True

	early_stopping.on_train_end()
	csv_logger.on_train_end()
	tensorboard.on_train_end({})
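
Example #3 relies on two helpers that are not shown, last_logs and average_logs. A plausible sketch, assuming last_logs extracts the final epoch's value of each metric from the History returned by fit_generator, and average_logs size-weights the per-dataset entries accumulated in epoch_history into one logs dict for the epoch-level callbacks:

import numpy as np

def last_logs(fit_history):
    # History.history maps each metric name to a per-epoch list;
    # with nb_epoch=1 the last element is the only one
    return {k: v[-1] for k, v in fit_history.history.items()}

def average_logs(epoch_history, train_sizes, valid_sizes):
    # epoch_history accumulated one entry per fit_generator call; weight each
    # entry by its dataset size (validation metrics by the validation sizes)
    merged = {}
    for key, values in epoch_history.history.items():
        sizes = valid_sizes if key.startswith('val_') else train_sizes
        merged[key] = float(np.average(values, weights=sizes))
    return merged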