HyperAI초신경
Back to Headlines

파이토치로 푸리에 신경망으로 난류 예측하기

3일 전

최근, 신경망이 수치 모델의 기능을 모방하면서도 그 제약을 극복하는 가능성을 탐구하는 데 큰 관심을 두고 있습니다. 수치 모델은 예측에 필요한 모든 파라미터를 계산해야 한다는 한계가 있습니다. 예를 들어, 나비어-스톡스 방정식에서 미래의 몇 단계를 예측하고자 할 때, 압력뿐만 아니라 밀도, 속도 등의 다른 파라미터도 모두 계산해야 합니다. 이는 작은 영역에서도 상당한 컴퓨팅 자원을 필요로 합니다. 또한, 시간과 공간 해상도를 일치시켜야 하는 Courant-Friedrichs-Lewy (CFL) 조건으로 인해 미래를 더 멀리 예측할수록 예측의 세부 사항이 손실되는 문제도 있습니다. 이러한 제약을 해결하기 위해 Fourier Neural Operator (FNO)를 활용할 수 있습니다. 기본 FNO FNO의 기본 워크플로우는 다음과 같습니다. 이미지 데이터 처리 시, 입력 데이터의 차원은 배치(B) x 시간 단계(T) x 채널(C) x 높이(H) x 너비(W)로 구성되며, 출력 데이터 역시 같은 형식입니다. 하지만 일반적으로 1개의 시간 단계만 예측하므로 출력 데이터의 차원은 4차원 (B, C, H, W)로 줄어들고, 1개의 파라미터만 처리하면 (B, H, W)로 감소합니다. 특징 추출기(FEATURE EXTRACTOR)는 FNO 블록에 입력된 데이터의 형태를 맞추기 위한 레이어입니다. 이 레이어는 총 깊이(시간 단계 x 채널)를 'hidden depth' 하이퍼파라미터 값으로 변경합니다. FNO 블록에서는 먼저 입력 데이터를 물리 공간에서 주파수 공간으로 변환하기 위해 Discrete Fourier Transform (DFT)를 적용합니다. 이 과정에서 가장 효율적인 알고리즘인 Fast Fourier Transform (FFT)를 사용합니다. FFT를 통해 필터링 단계에서 노이즈를 제거하고, 이 결과를 다시 역 FFT를 통해 원래의 물리 공간으로 복원합니다. FNO의 필터링 과정은 단순히 고주파 성분을 제거하는 것이 아니라, 저주파 성분의 중요성을 모델의 가중치로 조정하는 것입니다. 훈련 중 특정 주파수가 예측에 영향을 미치지 않는다면, 해당 주파수의 가중치는 낮아지거나 0이 됩니다. FNO 코드 FNO의 핵심 빌딩 블록은 SpectralConv2d_fast 클래스입니다. 이 클래스는 FFT를 사용하여 입력 데이터를 주파수 공간으로 변환하고, 필터링 단계를 거쳐 다시 물리 공간으로 복원합니다. ```python class SpectralConv2d_fast(nn.Module): def init(self, in_channels, out_channels, modes1, modes2): super(SpectralConv2d_fast, self).init() self.in_channels = in_channels self.out_channels = out_channels self.modes1 = modes1 # Fourier modes to multiply, at most floor(N/2) + 1 self.modes2 = modes2 self.scale = (1 / (in_channels * out_channels)) self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat)) self.weights2 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, self.modes2, dtype=torch.cfloat)) def forward(self, x): batchsize = x.shape[0] x_ft = torch.fft.rfft2(x, norm='ortho') out_ft = torch.zeros(batchsize, self.out_channels, x.size(-2), x.size(-1) // 2 + 1, dtype=torch.cfloat, device=x.device) out_ft[:, :, :self.modes1, :self.modes2] = compl_mul2d(x_ft[:, :, :self.modes1, :self.modes2], self.weights1) out_ft[:, :, -self.modes1:, :self.modes2] = compl_mul2d(x_ft[:, :, -self.modes1:, :self.modes2], self.weights2) x = torch.fft.irfft2(out_ft, s=(x.size(-2), x.size(-1)), norm='ortho') return x ``` 난류 예측 사례 FNO 모델의 성능을 평가하기 위해 난류 예측 사례를 사용했습니다. 데이터셋은 이 저장소에서 다운로드받았습니다. 데이터셋에는 밀도, 압력, 속도 등의 난류 파라미터가 포함되어 있으며, 총 1000개의 시간 단계가 있습니다. 이 테스트 케이스에서는 이전 밀도 시퀀스 10개를 사용하여 미래의 밀도를 예측했습니다. 데이터셋 클래스 ```python class DensityTurbulenceDataset(Dataset): def init(self, dataset_dir, seq_len_total): self.file_list = sorted(glob(os.path.join(dataset_dir, "*.npz"))) self.seq_len_total = seq_len_total self.data_cache = [] print(f"Loading dataset from {dataset_dir} into memory...") for f_path in tqdm(self.file_list, desc=f"Loading {os.path.basename(dataset_dir)}"): self.data_cache.append(np.load(f_path)['arr_0']) print(f"Dataset from {dataset_dir} loaded. Total files: {len(self.data_cache)}") def __len__(self): return len(self.file_list) - self.seq_len_total + 1 def __getitem__(self, idx): sequence_data = np.stack([self.data_cache[idx + i] for i in range(self.seq_len_total)]) inputs_full_sequence = sequence_data[:SEQ_LEN_FEATURE, :, :, :] targets_density_sequence = sequence_data[SEQ_LEN_FEATURE:, 0, :, :] return torch.from_numpy(inputs_full_sequence).float(), torch.from_numpy(targets_density_sequence).float() ``` 훈련 루프 ```python dataset = DensityTurbulenceDataset(dataset_dir, seq_len_total=SEQ_LEN_FEATURE + SEQ_LEN_ROLLOUT) loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) mean_std = np.load('density_stats.npz') mean = mean_std['mean'] std = mean_std['std'] model = Net2d(16, 20).to(device) mean_device = torch.from_numpy(mean).float().to(device) std_device = torch.from_numpy(std).float().to(device) learning_rate = 0.001 scheduler_step = 100 scheduler_gamma = 0.5 optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=scheduler_step, gamma=scheduler_gamma) l2loss = LpLoss(d=2, p=2, reduce_dims=(0,1)) h1_loss = H1Loss(d=2, reduce_dims=(0,1)) if os.path.exists(CHECKPOINT_TO_LOAD): checkpoint = torch.load(CHECKPOINT_TO_LOAD, map_location=device) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) last_epoch = checkpoint['epoch'] best_loss = checkpoint['loss'] print(f"Loaded checkpoint from epoch {last_epoch}, best loss: {best_loss:.6f}") model.train() for epoch in range(TOTAL_EPOCHS): if epoch <= last_epoch: continue total_loss = 0 loop = tqdm(loader, desc=f"Epoch {epoch+1}/{TOTAL_EPOCHS}") for inputs_seq_raw, targets_seq_raw in loop: inputs_seq_raw = inputs_seq_raw.to(device) targets_seq_raw = targets_seq_raw.to(device) inputs_seq_norm = (inputs_seq_raw - mean_device) / std_device targets_seq_norm = (targets_seq_raw - mean_device) / std_device optimizer.zero_grad() current_input_sequence = inputs_seq_norm predicted_density_norm = [] ground_truth_density_norm = [] for r_step in range(SEQ_LEN_ROLLOUT): x, y = create_embd_dim(len(current_input_sequence)) used_input_sequence = torch.cat((current_input_sequence, x, y), dim=1).unsqueeze(2) used_input_sequence = used_input_sequence.permute(0, 2, 3, 4, 1) pred_norm = model(used_input_sequence) pred_norm = pred_norm.squeeze() predicted_density_norm.append(pred_norm) current_gt_frame_norm = targets_seq_norm[:, r_step:r_step+1, :, :] ground_truth_density_norm.append(current_gt_frame_norm) prepared_next_input = pred_norm.unsqueeze(1) current_input_sequence = torch.cat([ current_input_sequence[:, 1:, :, :], prepared_next_input ], dim=1) predicted_density_norm = torch.cat(predicted_density_norm, dim=1) ground_truth_density_norm = torch.cat(ground_truth_density_norm, dim=1) loss = l2loss(predicted_density_norm, ground_truth_density_norm) + 0.2 * h1_loss(predicted_density_norm, ground_truth_density_norm) loss.backward() optimizer.step() total_loss += loss.item() loop.set_postfix(loss=loss.item()) scheduler.step() avg_loss = total_loss / len(loader) print(f"Epoch {epoch+1}, Avg Loss: {avg_loss:.6f}") if avg_loss < best_loss: best_loss = avg_loss torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': best_loss }, CHECKPOINT_TO_LOAD) print(f"Saved new best model at epoch {epoch+1} with loss {best_loss:.6f}") ``` 테스트 코드 훈련된 모델을 로드하고, 테스트 데이터셋을 사용하여 예측 결과를 생성합니다. ```python file_list = sorted(glob(os.path.join(dataset_dir, "*.npz")))[-SEQ_LEN_FEATURE:] seqdata2d = [] for iter_f in range(len(file_list)): single_data = np.load(file_list[iter_f])['arr_0'] seqdata2d.append(single_data) seqdata2d = np.array(seqdata2d) feature = torch.from_numpy(seqdata2d).float().to(device) feature = (feature - mean_device) / std_device feature = feature.squeeze(1).unsqueeze(0) file_list_test = sorted(glob(os.path.join(dataset_test_dir, "*.npz"))) seqtruth = [] for iter_f in range(len(file_list_test)): single_data = np.load(file_list_test[iter_f])['arr_0'] seqtruth.append(single_data) seqtruth = np.array(seqtruth) seqpred = [] current_input_sequence_norm = feature current_density_frame_norm = current_input_sequence_norm[:, -1:, :, :] with torch.no_grad(): for iter_step in range(100): x, y = create_embd_dim(len(current_input_sequence_norm)) x = x.squeeze(1) y = y.squeeze(1) used_input_sequence = torch.cat((current_input_sequence_norm, x, y), dim=1).unsqueeze(2) used_input_sequence = used_input_sequence.permute(0, 2, 3, 4, 1) pred_norm = model(used_input_sequence) pred_norm = pred_norm.squeeze() next_predicted_density_frame_denorm = (pred_norm * std_device) + mean_device next_predicted_density_frame_denorm = next_predicted_density_frame_denorm.detach().cpu().numpy() next_input = pred_norm.unsqueeze(0).unsqueeze(0) seqpred.append(next_predicted_density_frame_denorm) current_input_sequence_norm = torch.cat([ current_input_sequence_norm[:, 1:, :, :], next_input ], dim=1) seqpred = np.array(seqpred) vmin = np.min(seqtruth) vmax = np.max(seqtruth) for iter_p in range(len(seqpred)): fig, axes = plt.subplots(1, 2, figsize=(10, 6)) im0 = axes[0].imshow(seqtruth[iter_p,0], cmap='viridis', vmin=vmin, vmax=vmax) axes[0].set_title(f"Ground Truth - t={iter_p}") axes[0].axis('off') im1 = axes[1].imshow(seqpred[iter_p], cmap='viridis', vmin=vmin, vmax=vmax) axes[1].set_title(f"Prediction - t={iter_p}") axes[1].axis('off') fig.colorbar(im1, ax=axes[1], fraction=0.046, pad=0.04) plt.suptitle(f"Timestep {iter_p+1} - Ground Truth vs. Prediction (FNO)", fontsize=14) plt.tight_layout(rect=[0, 0, 1, 0.95]) fig.savefig(f"res_turb_mgfno2/comparison_t{iter_p:03d}.png", dpi=150) plt.close(fig) ``` 평가 FNO 모델은 U-net 모델보다 크기가 훨씬 작아 20MB 정도로, VRAM 부담이 적습니다. 테스트 결과, FNO는 U-net보다 더 적절하게 난류의 운동 패턴을 포착하였습니다. 특히, FNO 버전에서는 너무 빠른 운동이 없으며, 실제 데이터와 유사한 패턴을 보여주었습니다. FNO는 대규모 데이터셋에서의 효율적인 예측을 위해 설계되었으며, 주파수 공간에서의 필터링과 역 변환을 통해 물리적 현상을 효과적으로 모델링할 수 있습니다. 이는 난류와 같은 복잡한 유체 역학 문제에 대한 실시간 예측에 큰 도움이 될 것으로 기대됩니다.

Related Links