Example #1
from datetime import datetime

import tensorflow as tf

# PARSER, get_path, ensure_validation_split, analyse_dataset, create_tf_dataset,
# CVAE and LogImage are project-local helpers defined elsewhere.

def main():
    args = PARSER.parse_args()

    data_path = get_path(args, "record")
    model_save_path = get_path(args, "tf_vae", create=True)

    ensure_validation_split(data_path)
    _n_train, _avg_frames, mean, var = analyse_dataset(data_path)
    if args.normalize_images:
        train_data, val_data = create_tf_dataset(data_path, args.z_size, True, mean, var)
    else:
        train_data, val_data = create_tf_dataset(data_path, args.z_size)

    shuffle_size = 5 * 1000  # roughly 20 full episodes per shuffle window; a larger window increases RAM usage
    train_data = (
        train_data.shuffle(shuffle_size, reshuffle_each_iteration=True)
        .batch(args.vae_batch_size)
        .prefetch(2)
    )
    val_data = val_data.batch(args.vae_batch_size).prefetch(2)

    current_time = datetime.now().strftime("%Y%m%d-%H%M%S")
    tensorboard_dir = model_save_path / "tensorboard" / current_time

    vae = CVAE(args=args)
    vae.compile(optimizer=vae.optimizer, loss=vae.get_loss())
    vae.fit(train_data, validation_data=val_data, epochs=args.vae_num_epoch, callbacks=[
        tf.keras.callbacks.TensorBoard(log_dir=str(tensorboard_dir), update_freq=50, histogram_freq=1),
        LogImage(str(tensorboard_dir), val_data),
        tf.keras.callbacks.ModelCheckpoint(str(model_save_path / "ckpt-e{epoch:02d}"), verbose=1),
    ])
    vae.save(str(model_save_path))
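
# LogImage above is a project-specific callback whose implementation is not
# shown. A minimal sketch of what such an image-logging callback might look
# like; everything below is illustrative (the class name, and the assumption
# that the model's call() yields the reconstruction, are not from the original):
class LogImageSketch(tf.keras.callbacks.Callback):
    def __init__(self, log_dir, val_data):
        super().__init__()
        self.writer = tf.summary.create_file_writer(log_dir + "/images")
        self.val_data = val_data

    def on_epoch_end(self, epoch, logs=None):
        # take one validation batch, reconstruct it, and log both to TensorBoard
        batch = next(iter(self.val_data))
        if isinstance(batch, tuple):
            batch = batch[0]  # the dataset may yield (input, target) pairs
        recon = self.model(batch)
        if isinstance(recon, (list, tuple)):
            recon = recon[0]  # assume the first output is the reconstruction
        with self.writer.as_default():
            tf.summary.image("input", batch, step=epoch, max_outputs=4)
            tf.summary.image("reconstruction", recon, step=epoch, max_outputs=4)

# After training, the SavedModel written by vae.save() can be reloaded with
# tf.keras.models.load_model(str(model_save_path)).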
Example #2
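# The snippet below builds its dataset from a generator ds_gen that is not
# shown in the original example. A minimal sketch of what it might look like,
# assuming episodes are stored as .npz files holding an "obs" array of
# 64x64x3 frames (the "record/*.npz" layout and the "obs" key are assumptions):
import glob

import numpy as np

def ds_gen():
    for path in sorted(glob.glob("record/*.npz")):
        episode = np.load(path)
        for frame in episode["obs"]:
            # yield one float32 frame at a time, matching the
            # output_shapes=(64, 64, 3) declared below
            yield frame.astype(np.float32) / 255.0
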
import os

import numpy as np
import tensorflow as tf

# args (the parsed command-line arguments) and the CVAE model class are
# defined by the surrounding project and are not shown in this example.

if __name__ == "__main__":
    model_save_path = "results/{}/{}/tf_vae".format(args.exp_name, args.env_name)
    if not os.path.exists(model_save_path):
        os.makedirs(model_save_path)
    tensorboard_dir = os.path.join(model_save_path, 'tensorboard')
    summary_writer = tf.summary.create_file_writer(tensorboard_dir)
    summary_writer.set_as_default()
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=tensorboard_dir, write_graph=False)
    shuffle_size = 20 * 1000  # shuffle window covers only ~20 episodes, to keep RAM usage manageable
    ds = tf.data.Dataset.from_generator(ds_gen, output_types=tf.float32, output_shapes=(64, 64, 3))
    ds = ds.shuffle(shuffle_size, reshuffle_each_iteration=True).batch(args.vae_batch_size)
    ds = ds.prefetch(100)  # keep up to 100 batches buffered; tf.data.experimental.AUTOTUNE would size this automatically
    vae = CVAE(args=args)
    tensorboard_callback.set_model(vae)
    loss_weights = [1.0, 1.0]  # weight the reconstruction and KL losses equally
    vae.compile(optimizer=vae.optimizer, loss=vae.get_loss(), loss_weights=loss_weights)
    step = 0
    blank_batch = np.zeros([2 * args.z_size])  # dummy target for the 'KL' output passed to train_on_batch below
    for i in range(args.vae_num_epoch):
        j = 0
        for x_batch in ds:
            if i == 0 and j == 0:
                # build the model on the first batch (_set_inputs is a
                # private Keras API that fixes the model's input signature)
                vae._set_inputs(x_batch)
            j += 1
            step += 1

            loss = vae.train_on_batch(x=x_batch, y={'reconstruction': x_batch, 'KL': blank_batch}, return_dict=True)
            # log each loss component to TensorBoard at every step
            for loss_key, loss_val in loss.items():
                tf.summary.scalar(loss_key, loss_val, step=step)
            if j % 100 == 0:
                output_log = 'epoch: {} mb: {}'.format(i, j)
                for loss_key, loss_val in loss.items():
                    output_log += ' {}: {:.4f}'.format(loss_key, loss_val)
                print(output_log)
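
    # The original example ends here without saving the trained model. A minimal,
    # assumed continuation mirroring Example #1 (vae.save is the standard Keras
    # save API; whether CVAE supports it depends on the project's model class):
    vae.save(model_save_path)
    summary_writer.flush()  # push any buffered TensorBoard scalars to disk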