コード例 #1
0
ファイル: vae_train.py プロジェクト: twoletters/WorldModels
    # Fragment of a VAE training routine (the enclosing `def` is outside this view).
    # Streams (64, 64, 3) float32 frames from `ds_gen`, trains the CVAE one batch at
    # a time, logs every returned loss to TensorBoard, and saves the model after
    # each epoch.
    summary_writer = tf.summary.create_file_writer(tensorboard_dir)
    # Make this writer the default so the bare tf.summary.scalar calls below use it.
    summary_writer.set_as_default()
    tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=tensorboard_dir, write_graph=False)
    # Shuffle window of ~20 episodes' worth of frames (assumes ~1000 frames per
    # episode -- TODO confirm); deliberately small to bound RAM use.
    shuffle_size = 20 * 1000
    ds = tf.data.Dataset.from_generator(ds_gen, output_types=tf.float32, output_shapes=(64, 64, 3))
    ds = ds.shuffle(shuffle_size, reshuffle_each_iteration=True).batch(args.vae_batch_size)
    # Prefetch up to 100 batches (tf.data.experimental.AUTOTUNE is an alternative).
    ds = ds.prefetch(100)
    vae = CVAE(args=args)
    tensorboard_callback.set_model(vae)
    loss_weights = [1.0, 1.0]  # weight the reconstruction and KL losses equally
    vae.compile(optimizer=vae.optimizer, loss=vae.get_loss(), loss_weights=loss_weights)
    step = 0
    # Dummy target for the 'KL' output; presumably the KL loss ignores y_true.
    # NOTE(review): its leading dimension is 2*z_size, not the batch size -- confirm
    # the Keras version in use does not reject the mismatched sample count.
    blank_batch = np.zeros([2*args.z_size])
    for i in range(args.vae_num_epoch):
        j = 0
        for x_batch in ds:
            if i == 0 and j == 0:
                # Presumably gives the subclassed model a concrete input signature
                # (needed for saving) from the first real batch.
                vae._set_inputs(x_batch)
            j += 1
            step += 1

            loss = vae.train_on_batch(x=x_batch, y={'reconstruction': x_batch, 'KL': blank_batch}, return_dict=True)
            # Plain loop: these calls are made only for their side effect (writing
            # summaries), so don't build and discard a list with a comprehension.
            for loss_key, loss_val in loss.items():
                tf.summary.scalar(loss_key, loss_val, step=step)
            if j % 100 == 0:
                output_log = 'epoch: {} mb: {}'.format(i, j)
                for loss_key, loss_val in loss.items():
                    output_log += ', {}: {:.4f}'.format(loss_key, loss_val)
                print(output_log)
        print('saving')
        # Checkpoint after every epoch; writes to the same model_save_path each time.
        tf.keras.models.save_model(vae, model_save_path, include_optimizer=True, save_format='tf')
コード例 #2
0
ファイル: vae_train.py プロジェクト: hlynurd/wm-norb
    # Fragment of a VAE training routine (the enclosing `def` is outside this view).
    # Builds a shuffled, batched dataset, saves an untrained model once to check that
    # saving works, then trains batch-by-batch and logs losses to TensorBoard.
    dataset_size = nrecords * 1000  # 10k episodes each 1k steps long
    # Shuffle window of at most 20 records' worth of samples; kept small to bound RAM use.
    shuffle_size = min(20, nrecords) * 1000
    dataset = create_tf_dataset()
    dataset = dataset.shuffle(shuffle_size, reshuffle_each_iteration=True).batch(args.vae_batch_size)

    vae = CVAE(args=args)
    tensorboard_callback.set_model(vae)

    loss_weights = [1.0, 1.0]  # weight the reconstruction and KL losses equally
    vae.compile(optimizer=vae.optimizer, loss=vae.get_loss(), loss_weights=loss_weights)
    step = 0
    # Approximate number of minibatches per epoch (true division, so a float).
    n_mb = dataset_size / args.vae_batch_size
    # Give the model a concrete input signature so it can be saved before training.
    # NOTE(review): the spec is hard-coded to [1, 64, 64, 4] int32 -- confirm it
    # matches what create_tf_dataset() actually yields.
    input_spec = tf.TensorSpec([1, 64, 64, 4], tf.int32)
    vae._set_inputs(input_spec)
    tf.keras.models.save_model(vae, model_save_path, include_optimizer=True, save_format='tf')
    print("successfully saved")

    for i in range(args.vae_num_epoch):
        print('epoch: {}'.format(i))
        j = 0
        print("i is ", i)
        # Presumably each element is (input, reconstruction target, dummy KL target).
        for x_batch, targ_batch, blank_batch in dataset:
            j += 1
            step += 1
            # For a two-output model without metrics, Keras returns
            # [total loss, first-output loss, second-output loss].
            loss, recon_loss, reg_loss = vae.train_on_batch(x=x_batch, y=[targ_batch, blank_batch])
            with summary_writer.as_default():
                tf.summary.scalar('loss', loss, step=step)
                # Log the reconstruction term itself (previously the total loss was logged here).
                tf.summary.scalar('reconstruction loss', recon_loss, step=step)