Fourier Neural Operator Outperforms U-Net in Turbulence Prediction with Efficient Computation and Accurate Results
The Fourier Neural Operator (FNO) uses a neural network to approximate a physical model, which is particularly useful where traditional numerical methods are computationally expensive. One common limitation of numerical solvers, such as those for the Navier-Stokes equations, is that they must compute every state variable even when only one is of interest. In addition, the time step must stay small relative to the grid spacing for the computation to remain stable, a constraint known as the Courant-Friedrichs-Lewy (CFL) condition, so high spatial resolution also forces high temporal resolution and raises the cost further. FNO addresses these issues by transforming and filtering the data in the frequency domain, which can make predictions considerably cheaper while preserving accuracy.

Basic Structure of FNO

Input and Output Dimensions

The FNO model processes data in 5 dimensions: Batch (B), Timesteps (T), Channels (C), Height (H), and Width (W). The input is reshaped to fit these dimensions, and the output is typically reduced to 4 dimensions (B, C, H, W) when predicting a single timestep for a single parameter.

Feature Extractor

The feature extractor layer maps the input tensor to the shape expected by the FNO blocks. It reshapes the input and projects it to a fixed number of channels, the "hidden depth" (width in the code), which is a hyperparameter.

Fourier Convolution Block

The core of the FNO is the Fourier Convolution Block (SpectralConv2d_fast), which performs the following steps:

1. Fast Fourier Transform (FFT): Transforms the input from physical space to frequency space using torch.fft.rfft2.
2. Filtering: Keeps only the lowest-frequency Fourier modes (modes1 and modes2 in the code) and discards the rest, which suppresses high-frequency noise. For each retained mode, the Fourier coefficients are multiplied by learned complex weights that mix the input channels into the output channels.
3. Inverse FFT (IFFT): Converts the filtered data back to physical space using torch.fft.irfft2.

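As a rough illustration of the three steps of the Fourier Convolution Block, the following pure-Python sketch applies them to a 1D signal. It is not the article's implementation: it uses a naive O(n²) DFT in place of torch.fft.rfft2 and torch.fft.irfft2, works in 1D rather than 2D, and the names dft, idft and spectral_filter, as well as the identity weights, are assumptions made for this example. In an FNO the weights are learned complex tensors.

```python
import cmath
import math


def dft(signal):
    """Naive discrete Fourier transform (stand-in for an FFT routine)."""
    n = len(signal)
    return [
        sum(signal[t] * cmath.exp(-2j * math.pi * k * t / n) for t in range(n))
        for k in range(n)
    ]


def idft(coeffs):
    """Naive inverse DFT; returns the real part of each sample."""
    n = len(coeffs)
    return [
        (sum(coeffs[k] * cmath.exp(2j * math.pi * k * t / n) for k in range(n)) / n).real
        for t in range(n)
    ]


def spectral_filter(signal, modes, weights=None):
    """Keep the `modes` lowest frequencies, scale them, and transform back."""
    n = len(signal)
    if modes > n // 2:
        raise ValueError("modes must not exceed half the signal length")
    if weights is None:
        weights = [1.0] * modes  # hypothetical identity weights; an FNO learns these
    coeffs = dft(signal)  # step 1: physical space -> frequency space
    filtered = [0j] * n   # every mode that is not copied below is discarded
    for k in range(modes):  # step 2: weight the retained low-frequency modes
        filtered[k] = weights[k] * coeffs[k]
        if k > 0:
            # Mirror onto the matching negative frequency so the output stays real.
            filtered[n - k] = weights[k].conjugate() * coeffs[n - k]
    return idft(filtered)  # step 3: frequency space -> physical space


# A slow sine wave plus a fast oscillation that plays the role of noise.
n = 32
low = [math.sin(2 * math.pi * t / n) for t in range(n)]
high = [0.5 * math.sin(2 * math.pi * 12 * t / n) for t in range(n)]
noisy = [a + b for a, b in zip(low, high)]

smoothed = spectral_filter(noisy, modes=4)
max_err = max(abs(s - l) for s, l in zip(smoothed, low))
print(f"max deviation from the low-frequency component: {max_err:.2e}")
```

Because the fast oscillation sits at frequency 12 and only the four lowest modes are kept, the filter removes it and returns the slow sine wave. Replacing the identity weights with trainable ones is what turns this fixed low-pass filter into a learnable layer.
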
MLP Layers

Multilayer Perceptrons (MLPs) are used within the FNO blocks to refine the filtered data. In this implementation, each MLP consists of two 1×1 convolutions with a GELU activation between them.

Implementation Example: Turbulence Prediction

Dataset Loading

The dataset is loaded with a custom dataset class, DensityTurbulenceDataset, which reads and processes .npz files containing turbulence data. The dataset includes sequences of past and future timesteps for parameters such as density, velocity, and pressure.

Model Definition

The FNO model is defined in two main classes:

1. SpectralConv2d_fast: Handles the Fourier convolution process.
2. SimpleBlock2d: Combines multiple SpectralConv2d_fast blocks with instance normalization and MLP layers.

The Net2d class wraps the SimpleBlock2d model and provides a forward pass.

Training Loop

The training loop involves:

1. Data normalization: Using a pre-calculated mean and standard deviation.
2. Unrolling prediction: Predicting future timesteps one at a time, feeding each prediction back in as input for the next step.
3. Loss calculation: Combining L2 and H1 losses (the H1 term is weighted by 0.2) to evaluate prediction accuracy.
4. Backpropagation and optimization: Updating the model parameters based on the calculated loss.

Testing and Results

After training, the model is tested using the following steps:

1. Load the trained model: From the saved checkpoint.
2. Normalization: Apply the same mean and standard deviation used during training.
3. Unrolling prediction: As in the training loop, but without gradient updates.
4. Visualization: Compare the ground truth and predicted density fields over time.

Comparison with U-Net

Compared with a U-Net model, the FNO version captures the movement patterns of turbulence more accurately. While the U-Net model predicts the movement of "bubbles" too quickly, the FNO model maintains a more accurate representation of the fluid dynamics.
This is largely attributed to FNO's ability to process and filter data efficiently in the frequency domain, which reduces computational overhead and improves predictive accuracy.

Code Snippets

SpectralConv2d_fast

```python
import torch
import torch.nn as nn
import torch.fft


class SpectralConv2d_fast(nn.Module):
    def __init__(self, in_channels, out_channels, modes1, modes2):
        super(SpectralConv2d_fast, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  # number of Fourier modes kept along height
        self.modes2 = modes2  # number of Fourier modes kept along width
        self.scale = (1 / (in_channels * out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))
        self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat))

    def forward(self, x):
        batchsize = x.shape[0]
        # Step 1: physical space -> frequency space
        x_ft = torch.fft.rfft2(x, norm='ortho')
        # Step 2: weight the retained low-frequency modes; all other modes stay zero
        out_ft = torch.zeros(batchsize, self.out_channels, x.size(-2), x.size(-1) // 2 + 1, dtype=torch.cfloat, device=x.device)
        out_ft[:, :, :self.modes1, :self.modes2] = compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1)
        out_ft[:, :, -self.modes1:, :self.modes2] = compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2)
        # Step 3: frequency space -> physical space
        x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1)), norm='ortho')
        return x


def compl_mul2d(input, weights):
    # (batch, in_channel, x, y), (in_channel, out_channel, x, y) -> (batch, out_channel, x, y)
    return torch.einsum("bixy,ioxy->boxy", input, weights)
```

SimpleBlock2d

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import reduce
import operator


class MLP(nn.Module):
    def __init__(self, in_channels, out_channels, mid_channels):
        super(MLP, self).__init__()
        # 1x1 convolutions act as a per-pixel MLP over the channel dimension
        self.mlp1 = nn.Conv2d(in_channels, mid_channels, 1)
        self.mlp2 = nn.Conv2d(mid_channels, out_channels, 1)

    def forward(self, x):
        x = self.mlp1(x)
        x = F.gelu(x)
        x = self.mlp2(x)
        return x


class SimpleBlock2d(nn.Module):
    def __init__(self, modes1, modes2, width):
        super(SimpleBlock2d, self).__init__()
        self.modes1 = modes1
        self.modes2 = modes2
        self.width = width
        self.padding = 8
        # Feature extractor: 10 input timesteps + 2 coordinate channels -> width
        self.p = nn.Linear(12, width)
        self.conv0 = SpectralConv2d_fast(width, width, modes1, modes2)
        self.conv1 = SpectralConv2d_fast(width, width, modes1, modes2)
        self.conv2 = SpectralConv2d_fast(width, width, modes1, modes2)
        self.conv3 = SpectralConv2d_fast(width, width, modes1, modes2)
        self.mlp0 = MLP(width, width, width)
        self.mlp1 = MLP(width, width, width)
        self.mlp2 = MLP(width, width, width)
        self.mlp3 = MLP(width, width, width)
        self.w0 = nn.Conv2d(width, width, 1)
        self.w1 = nn.Conv2d(width, width, 1)
        self.w2 = nn.Conv2d(width, width, 1)
        self.w3 = nn.Conv2d(width, width, 1)
        self.norm = nn.InstanceNorm2d(width)
        self.q = MLP(width, 1, width * 4)

    def forward(self, x):
        # x: (B, 1, H, W, 12) -> (B, 1, H, W, width) -> (B, width, H, W)
        x = self.p(x)
        x = x.permute(0, 4, 1, 2, 3)
        x = x.reshape(x.shape[0], -1, x.shape[3], x.shape[4])

        # Four FNO blocks: spectral path + pointwise path + residual
        x1 = self.norm(self.conv0(self.norm(x)))
        x1 = self.mlp0(x1)
        x2 = self.w0(x)
        x = x1 + x2 + x
        x = F.gelu(x)

        x1 = self.norm(self.conv1(self.norm(x)))
        x1 = self.mlp1(x1)
        x2 = self.w1(x)
        x = x1 + x2 + x
        x = F.gelu(x)

        x1 = self.norm(self.conv2(self.norm(x)))
        x1 = self.mlp2(x1)
        x2 = self.w2(x)
        x = x1 + x2 + x
        x = F.gelu(x)

        x1 = self.norm(self.conv3(self.norm(x)))
        x1 = self.mlp3(x1)
        x2 = self.w3(x)
        x = x1 + x2 + x
        x = F.gelu(x)

        # Project to a single output channel: (B, 1, H, W) -> (B, H, W, 1)
        x = self.q(x)
        x = x.permute(0, 2, 3, 1)
        return x
```

Net2d and Training

```python
import os
import operator
from functools import reduce

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
# Assumption: LpLoss and H1Loss come from the neuraloperator package,
# as the checkpoint name suggests.
from neuralop.losses import LpLoss, H1Loss

# DensityTurbulenceDataset, create_embd_dim and dataset_dir are defined
# elsewhere and are not shown in this article.


class Net2d(nn.Module):
    def __init__(self, modes, width):
        super(Net2d, self).__init__()
        self.conv1 = SimpleBlock2d(modes, modes, width)

    def forward(self, x):
        x = self.conv1(x)
        return x

    def count_params(self):
        c = 0
        for p in self.parameters():
            c += reduce(operator.mul, list(p.size()))
        return c


# --- Training Configuration ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEQ_LEN_FEATURE = 10
SEQ_LEN_ROLLOUT = 25
WIDTH_IMG = 128
HEIGHT_IMG = 64
BATCH_SIZE = 8
TOTAL_EPOCHS = 300
LEARNING_RATE = 0.001
SCHEDULER_STEP = 100
SCHEDULER_GAMMA = 0.5
CHECKPOINT_PATH = 'best_mgfnol2losshlossneuralop_turbulence_unroll.pth'
MEAN_STD_PATH = 'density_stats.npz'

# --- Load Mean and Std ---
mean_std = np.load(MEAN_STD_PATH)
mean_device = torch.from_numpy(mean_std['mean']).float().to(device)
std_device = torch.from_numpy(mean_std['std']).float().to(device)

# --- Load Dataset ---
dataset = DensityTurbulenceDataset(dataset_dir, seq_len_total=SEQ_LEN_FEATURE + SEQ_LEN_ROLLOUT)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# --- Initialize Model and Optimizer ---
model = Net2d(16, 20).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=SCHEDULER_STEP, gamma=SCHEDULER_GAMMA)
l2loss = LpLoss(d=2, p=2, reduce_dims=(0, 1))
h1_loss = H1Loss(d=2, reduce_dims=(0, 1))

# --- Checkpoint Loading ---
# Defaults for a fresh run, so the loop below works without a checkpoint
last_epoch = -1
best_loss = float('inf')
if os.path.exists(CHECKPOINT_PATH):
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    last_epoch = checkpoint['epoch']
    best_loss = checkpoint['loss']
    print(f"Loaded checkpoint from epoch {last_epoch}, best loss: {best_loss:.6f}")

# --- Training Loop ---
model.train()
for epoch in range(TOTAL_EPOCHS):
    if epoch <= last_epoch:
        continue  # skip epochs already covered by the checkpoint
    total_loss = 0
    loop = tqdm(loader, desc=f"Epoch {epoch+1}/{TOTAL_EPOCHS}")
    for inputs_seq_raw, targets_seq_raw in loop:
        inputs_seq_raw = inputs_seq_raw.to(device)
        targets_seq_raw = targets_seq_raw.to(device)
        inputs_seq_norm = (inputs_seq_raw - mean_device) / std_device
        targets_seq_norm = (targets_seq_raw - mean_device) / std_device

        optimizer.zero_grad()
        current_input_sequence = inputs_seq_norm  # (B, SEQ_LEN_FEATURE, H, W)
        predicted_density_norm = []
        ground_truth_density_norm = []

        for r_step in range(SEQ_LEN_ROLLOUT):
            # Coordinate grids are appended as two extra input channels
            x, y = create_embd_dim(len(current_input_sequence))
            x = x.squeeze(1)
            y = y.squeeze(1)
            used_input_sequence = torch.cat((current_input_sequence, x, y), dim=1).unsqueeze(2)
            used_input_sequence = used_input_sequence.permute(0, 2, 3, 4, 1)

            pred_norm = model(used_input_sequence)     # (B, H, W, 1)
            # Move the channel axis to dim 1 instead of calling squeeze(), so the
            # shape matches the ground-truth frames and batch size 1 still works
            pred_norm = pred_norm.permute(0, 3, 1, 2)  # (B, 1, H, W)
            predicted_density_norm.append(pred_norm)

            current_gt_frame_norm = targets_seq_norm[:, r_step:r_step+1, :, :]
            ground_truth_density_norm.append(current_gt_frame_norm)

            # Slide the input window: drop the oldest frame, append the prediction
            current_input_sequence = torch.cat([
                current_input_sequence[:, 1:, :, :],
                pred_norm
            ], dim=1)

        predicted_density_norm = torch.cat(predicted_density_norm, dim=1)
        ground_truth_density_norm = torch.cat(ground_truth_density_norm, dim=1)

        loss = l2loss(predicted_density_norm, ground_truth_density_norm) + 0.2 * h1_loss(predicted_density_norm, ground_truth_density_norm)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        loop.set_postfix(loss=loss.item())

    scheduler.step()
    avg_loss = total_loss / len(loader)
    print(f"Epoch {epoch+1}, Avg Loss: {avg_loss:.6f}")

    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': best_loss
        }, CHECKPOINT_PATH)
        print(f"Saved new best model at epoch {epoch+1} with loss {best_loss:.6f}")

# --- Final Save ---
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': total_loss / len(loader)
}, 'final_mgfnol2losshlossneuralop_turbulence_unroll.pth')
print("Training complete.")
```

Evaluation and Insights

The Fourier Neural Operator (FNO) has shown significant promise in fluid dynamics, particularly for turbulence prediction. By working with Fourier transforms, FNO can handle high-dimensional data efficiently and capture global characteristics of the flow, which are crucial for accurate predictions. Compared with traditional numerical models, and even with other deep learning models such as U-Net, FNO offers a balance of computational efficiency and predictive accuracy, making it a versatile tool for complex physical simulations.
The model's ability to maintain detailed spatial and temporal resolution while avoiding the computational bottlenecks associated with traditional methods is particularly noteworthy. This makes FNO a valuable approach in scenarios where real-time or near-real-time predictions are required, such as weather forecasting and aerodynamics.

Industry Insights and Company Profiles

The development and application of FNO reflect the broader trend in AI towards creating more efficient and scalable solutions for physical modeling. Companies like NVIDIA and Google have also explored similar techniques, emphasizing the importance of frequency-domain processing in advancing the capabilities of deep learning models. The success of FNO in turbulence prediction highlights its potential for use in various industries, including aerospace, automotive, and meteorology, where accurate and fast simulation of fluid dynamics is crucial.