A terrain synthesizer that can run on a smartphone, generate thousands of high-quality samples each second, and be controlled through text or sketches would make it possible to create endless, hyper-realistic digital landscapes tailored to a user's experience and preferences.
This project is a step towards this vision.
Creating visually appealing and realistic terrains is notoriously difficult and time-consuming, yet it is essential to modern computer graphics and video game development.
Digital terrain creation is plagued by several problems. Skilled labor is expensive, and most game studios cannot afford to pay digital artists to create terrains by hand.
The algorithmic approach, called procedural generation, has proven efficient in many scenarios. Algorithms such as Perlin noise can create digital elevation maps cheaply; however, these methods often produce bland, low-fidelity samples that fail to mimic the geomorphic qualities of real terrains. The samples also look quite similar to each other, so tiling them to build larger terrains tends to produce repetitive, low-quality results.
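For illustration, a fractal heightmap built from layered Perlin noise might look like this minimal sketch, assuming the noise package is installed (all parameters are arbitrary):

# a minimal Perlin-noise heightmap, for illustration (parameters are arbitrary)
import numpy as np
from noise import pnoise2

def perlin_heightmap(size=64, scale=32.0, octaves=4):
    xs, ys = np.meshgrid(np.arange(size), np.arange(size))
    # pnoise2 layers `octaves` frequencies of gradient noise per point
    heightmap = np.vectorize(
        lambda x, y: pnoise2(x / scale, y / scale, octaves=octaves)
    )(xs, ys)
    return heightmap  # values roughly in [-1, 1]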
Recent advances in generative AI have made good-quality image generation possible. Learning from large datasets, architectures like Generative Adversarial Networks (GANs) and diffusion models have the potential to generate high-quality terrains. However, GANs are notoriously hard to train because of their adversarial setup, and diffusion models require many evaluation steps to produce sufficient samples, with some variants requiring thousands of steps to generate a single batch of images.
Consistency Models

The success of diffusion models led researchers to develop techniques to reduce the computational cost of diffusion inference. The EDM models and paper by Karras et al. highlighted several areas where previous diffusion architectures could be improved. They use the Stochastic and Ordinary Differential Equation (SDE, ODE) formalisms from Song et al. to describe the noising and denoising processes as forward- and reverse-time diffusion processes, achieving state-of-the-art image generation results with fewer steps than models of comparable sample quality.
In 2023, Song et al. introduced a new family of generative models called Consistency Models. Building on the EDM architecture, the method defines a constraint called the consistency function, which requires the model to map every point on an ODE trajectory to the same value. The authors also add a boundary condition so that the consistency function's value at t=0 (no added noise) equals the original data point x0. This way, each unique data point has its own separate ODE trajectory.
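In practice, the boundary condition can be enforced by construction with EDM-style skip scalings; the following is a minimal sketch following the coefficients in the Consistency Models paper (not the exact project code):

# skip-connection parameterization that enforces f(x, sigma_min) = x by construction
import jax.numpy as jnp

def consistency_output(x, sigma, sigma_data, sigma_min, network_out):
    c_skip = sigma_data**2 / ((sigma - sigma_min)**2 + sigma_data**2)
    c_out = sigma_data * (sigma - sigma_min) / jnp.sqrt(sigma**2 + sigma_data**2)
    return c_skip * x + c_out * network_out

At sigma = sigma_min, c_skip is 1 and c_out is 0, so the network output is ignored and the input is returned unchanged, satisfying the boundary condition.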
Consistency Models come in two forms: Consistency Distillation (CD) and Consistency Training (CT). With CD, the model is trained to match the outputs of a readily available diffusion model; with CT, the Consistency Model is trained from scratch. Improved Techniques for Training Consistency Models (ICT) refined the consistency training setup, and its authors outperformed Consistency Distillation and many few-step diffusion approaches.
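To give an idea of the training objective, here is a simplified sketch of a consistency training step on one pair of adjacent noise levels (the ICT paper uses a pseudo-Huber distance and a carefully tuned schedule; squared L2 and the function names here are stand-ins):

# simplified consistency training loss for one pair of adjacent noise levels
import jax
import jax.numpy as jnp

def ct_loss(params, apply_fn, x, z, sigma_n, sigma_n1):
    # two points on the same ODE trajectory, sharing the noise z
    x_noisier = x + sigma_n1 * z
    x_cleaner = x + sigma_n * z
    out_noisier = apply_fn(params, x_noisier, sigma_n1)
    # the "teacher" is the same network with gradients stopped (no EMA, as in ICT)
    out_cleaner = jax.lax.stop_gradient(apply_fn(params, x_cleaner, sigma_n))
    return jnp.mean((out_noisier - out_cleaner) ** 2)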
The single-step sampling routine used in this project:

# sampling from a consistency model
import jax
import jax.numpy as jnp
from jax import random

def sample_single_step(key, denoising_fn, denoising_params,
                       shape, sigma_data, sigma_min,
                       sigma_max, context) -> jax.Array:
    # generate noise at the maximum noise scale
    xT = random.normal(key, shape) * sigma_max
    # turn the start time into a batch-sized vector
    sigmas = sigma_max * jnp.ones(shape[:1])
    # consistency_fn (defined in the project code) maps the noisy input
    # straight back to the data manifold
    _, sample = consistency_fn(xT, context, sigmas, sigma_data, sigma_min,
                               denoising_fn, denoising_params, train=False)
    sample = jnp.clip(sample, -1, 1)
    return sample
Fast sampling is critical in numerous digital terrain generation settings, especially for open-world terrains. CMs have a remarkable advantage over classical diffusion models: they can generate samples in a single forward evaluation step. This constitutes a 30x-1000x speedup compared to classical diffusion variants, meaning Consistency Models have the potential to generate realistic digital terrains efficiently.
Network Architecture

Consistency models need a neural network to approximate the accumulated noise between timesteps. The most popular choice for generative models is the so-called U-Net. This architecture has a contracting and an expanding path: the contracting path shrinks the data dimension and extracts feature maps, and the expanding path increases the data dimension to assemble the transformed data.
The U-Net used for this project is similar to the architecture used for training the models on CIFAR-10 in Song et al., although the two models differ in some aspects.
In the contracting path, the model downsamples the 64x64 input data to 8x8 feature maps. The layers of the U-Net consist of ResNet blocks and attention blocks. The model has transformer blocks at the 32x32 and 8x8 resolutions, where attention blocks are sandwiched between ResNet blocks.
Compared to the ICT model used for CIFAR-10, the following changes were implemented.
Time embeddings: Instead of Fourier embeddings, the U-Net represents the noise scale with the sinusoidal embeddings from Vaswani et al. Instead of 16-dimensional positional embeddings, the terrain network uses 768 positional features (see the next section for the reasoning, and the sketch after this list).
Network blocks: The new network uses 3 attention and ResNet blocks instead of the 4 in the original architecture and does not implement feature pyramids. Instead of self-attention, the U-Net uses cross-attention (see the next section).
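For reference, the sinusoidal noise-scale embedding could look like the following sketch (the function name and the max_period constant are illustrative):

# sinusoidal embedding of the noise scale, in the style of Vaswani et al.
import jax.numpy as jnp

def sinusoidal_embedding(sigmas, dim=768, max_period=10000.0):
    half = dim // 2
    # geometrically spaced frequencies, from 1 down to 1/max_period
    freqs = jnp.exp(-jnp.log(max_period) * jnp.arange(half) / half)
    args = sigmas[:, None] * freqs[None, :]
    return jnp.concatenate([jnp.sin(args), jnp.cos(args)], axis=-1)  # (batch, dim)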
Text-guided Generation

In most cases, the ability to guide the creation process is one of the most important requirements for a terrain synthesizer. Recently, numerous guidance methods have emerged to generate samples conditioned on heterogeneous inputs.
The most straightforward way for a layperson to control a terrain generation model is by passing natural language instructions to the network, a task commonly addressed with Classifier-Free Guidance. To this end, the present project uses CLIP, an approach that embeds images and their textual descriptions into the same latent space. By putting a textual description through a CLIP text encoder, its latent representation can be used to generate elevation maps conditionally.
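Obtaining such an embedding could look like the following sketch, assuming the HuggingFace transformers CLIP implementation (the checkpoint name is an assumption; any CLIP text encoder producing 77x768 outputs works):

# embedding a terrain description with a CLIP text encoder
from transformers import CLIPTokenizer, FlaxCLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = FlaxCLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

tokens = tokenizer(["a rugged mountain ridge with deep glacial valleys"],
                   padding="max_length", max_length=77, return_tensors="np")
context = text_encoder(**tokens).last_hidden_state  # shape (1, 77, 768)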
The embedding of a terrain description is cast to a 77x768 matrix. The U-Net incorporates the CLIP embeddings through cross-attention: the layer input serves as the query (Q) matrix, and the context embedding constitutes the key (K) and value (V) matrices. The CLIP embeddings are passed to every attention block in the network.
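A minimal cross-attention block in this spirit, assuming Flax linen (the class name, head count, and residual connection are illustrative, not the project's exact block):

# cross-attention over CLIP context: image features query the text embedding
import flax.linen as nn

class CrossAttentionBlock(nn.Module):
    num_heads: int = 8

    @nn.compact
    def __call__(self, x, context):
        # x: (batch, h*w, channels) image features -> queries
        # context: (batch, 77, 768) CLIP embedding -> keys and values
        attended = nn.MultiHeadDotProductAttention(num_heads=self.num_heads)(x, context)
        return x + attended  # residual connection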
When no textual description is available, the network receives the CLIP embedding of an empty string.
Dataset Curation

Nowadays, many Digital Elevation Map (DEM) datasets are publicly accessible. They vary in quality, resolution, and land coverage, and can be leveraged to train machine learning models capable of generating high-fidelity terrains.
Elevation dataset: This work uses the NASADEM Digital Elevation Model, a collection of elevation maps derived from interferograms of multiple radar images. The dataset contains approximately 14,000 samples of 3601x3601 GeoTIFF elevation maps, covering a large portion of Earth.
A subset of 4,096 samples was randomly selected and then segmented into 512x512 pixel slices, resulting in approximately 400K slices. After dropping elevation maps with zero-only elevation, the remaining 360K+ samples were ranked in decreasing order of Shannon entropy, and the top 100K slices were kept as the most "interesting" ones. These slices were then downscaled to 64x64 pixels for computational efficiency.
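The entropy ranking could be implemented along these lines (a sketch; the helper name and binning are illustrative, and slices is assumed to be a list of 512x512 arrays):

# ranking elevation slices by the Shannon entropy of their height histogram
import numpy as np

def shannon_entropy(slice_, bins=256):
    hist, _ = np.histogram(slice_, bins=bins)
    p = hist / hist.sum()
    p = p[p > 0]  # ignore empty bins so log2 is defined
    return -np.sum(p * np.log2(p))

# keep the 100K most "interesting" slices
ranked = sorted(slices, key=shannon_entropy, reverse=True)[:100_000]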
Terrain descriptions: From the curated terrain dataset, 20K elevation maps were selected and transformed into colored topographic maps using a 256-level colormap from this repository. Based on the longitude and latitude information from the original NASADEM samples, information such as the country and region of each elevation map was retrieved through reverse geocoding with the GeoPy library, using the OpenStreetMap Nominatim API. The images were fed into GPT-4o Mini to generate textual descriptions of the terrains, resulting in approximately 18K natural language instructions. GPT-4o Mini generated the descriptions per the following prompt:
You are tasked with captioning geographical maps to distill their scenic properties. {info_str}
Provide a description of the terrain visible on the image with vivid geomorphic detail. Avoid specifying colors, rather talk about the geomorphic elements and elevation. Remember: the goal is to specify as many artifacts as possible with terminology from geography. The geographical names in the captions should be translated to English. Do not use more than two sentences.
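The reverse geocoding step could look like this sketch (the user agent string, the coordinates, and the exact address fields are placeholders):

# looking up country/region for a tile center with GeoPy's Nominatim geocoder
from geopy.geocoders import Nominatim

geolocator = Nominatim(user_agent="terrain-captioning")  # placeholder user agent
location = geolocator.reverse((lat, lon), language="en")  # tile center coordinates
address = location.raw.get("address", {})
info_str = f"Country: {address.get('country')}, region: {address.get('state')}"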
Code: The model and the training code were mainly written using JAX and Flax. The implementation of the Consistency Model components mostly relied on the original research paper and the official implementation for CIFAR-10. Unless mentioned explicitly, the network architecture details and parameters correspond to those in the ICT paper.
EMA: While the original Consistency Training used an exponential moving average (EMA) to update the model parameters, this was avoided in this project due to the time overhead of EMA updates.
Dataloaders: The elevation files are stored in GeoTIFF (.tif) format, and the context embeddings are stored as NumPy tensors. Both are read by a custom TFDS dataset. The source code also contains PyTorch datasets for reading from zip archives and directories, but these were not used for the final training run due to their slow performance (because JAX itself is multiprocessed, the PyTorch dataloaders could only use a single worker).
Optimizer: The model was trained using the RAdam optimizer with a learning rate of 0.0002 and the maximum number of training steps set at 800,000.
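Assuming Optax (which ships a radam implementation), the optimizer setup could look like:

# optimizer setup, assuming Optax's radam implementation
import optax

optimizer = optax.radam(learning_rate=2e-4)
opt_state = optimizer.init(denoising_params)  # denoising_params: the U-Net parameters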
The elevation data was collected and preprocessed in Google Colab due to the large amount of RAM available. The terrain descriptions were likewise generated in Colab for convenience.
Results

Unfortunately, training a 64x64 consistency model is computationally demanding. Due to the contest's time constraints and errors that occurred during model training, the latest available terrain consistency checkpoint is at 255,000 training steps.
Consistency training has an irregular loss curve due to the timestep schedule, so completing the training run is expected to improve model performance substantially.
Even after completing just 32% of the run, a subset of the generated landscapes clearly shows emerging fractal-like properties, and the smoothed results resemble real terrains. The trained model achieves an FID score of 98.8, compared to 180.4 for the untrained baseline.
Below are some qualitative samples from the latest (255,000-step) model checkpoint. The plots were made with the code from this article.
Apart from finishing the training runs and evaluating text-guided generation, the architecture and workflow outlined in this project could be improved in several ways.
Dataset collection: Training a terrain generation model with significantly more data could improve sample quality tremendously. The greatest limitation of the current project is the dataset size and sample quality. With under 20,000 textual descriptions, the network is unlikely to generate diverse terrains or respond well to rare instructions. The prompts used with the chosen captioning model also greatly impact the results.
When filtering elevation maps, ranking by entropy means the model never sees many "boring" samples, such as plains, and will thus be biased. A more advanced filtering method could improve sample diversity.
Autoencoders: The current architecture denoises 64x64 images, though many use cases call for higher-resolution terrains. Using an encoder-decoder architecture - such as the one in Stable Diffusion - to represent high-dimensional images or elevation maps as lower-dimensional latent vectors (and applying the consistency function to these vectors) could enhance sample quality while preserving generation speed.
Guidance: Guiding terrain generation by means other than text would let experienced users obtain better samples and gain more control over the results. Implementing methods such as classifier and sketch guidance could improve the neural terrain generation workflow.
Inverse problems: Adding features for various image processing tasks is crucial in many terrain synthesis settings. Problems such as super-resolution and inpainting are essential for creating high-quality, tileable terrains, and could be implemented readily based on previous works.
Notes

The training script is accessible and the results are reproducible from the attached source code or the GitHub repository of the project. The latest checkpoint and the full list of Python dependencies used directly in the project can be found in the setup file of the attached source code.
- Running JAX on a ROCm device requires a custom installation that can be done with the ROCm Docker image or by building from source. For training the model for the project, the ROCm JAX Docker image was used.
- With installations of Keras 3, the jax-fid module encounters an error. This can be circumvented by replacing the original ImageDataGenerator import in jax_fid/fid.py with from keras.src.legacy.preprocessing.image import ImageDataGenerator.
- The data archives take up quite a lot of space. As an alternative to unpacking the zip files, using ZippedTerrainDataset could reduce the required disk space, though this will result in a significant drop in training speed. Using TFDS, the Radeon device learns at approximately 2 to 3 iterations per second, while using a PyTorch dataloader results in a speed of less than 0.1 iterations per second.
Cover photo: Sir John Murray's map of the Indian Ocean, Chart 1C, accompanying the Summary of Results of the Challenger Expedition, 1895 via Unsplash