
초록
우리는 METR-LA 교통 예측 과제에서 기존에 최고 성능을 기록했던 Graph WaveNet의 성능을 향상시키기 위한 일련의 개선 사항을 제안한다. 이 과제의 목적은 과거 1시간 동안의 센서 측정 데이터를 이용하여 네트워크 내 각 센서의 향후 교통 속도를 예측하는 것이다. Graph WaveNet(GWN)은 근접한 센서로부터의 정보를 집계하기 위해 그래프 컨볼루션을, 과거 정보를 집계하기 위해 확대 컨볼루션(dilated convolutions)을 번갈아 사용하는 공간-시간 그래프 신경망이다. 우리는 다음과 같은 세 가지 방식으로 GWN을 개선한다: (1) 더 나은 하이퍼파라미터 사용, (2) 초기 컨볼루션 계층으로 더 큰 기울기가 되돌아올 수 있도록 연결 추가, (3) 더 쉬운 단기 교통 예측 과제에서 사전 학습(pretraining) 수행. 이러한 개선 사항들은 METR-LA 과제에서 평균 절대 오차(mean absolute error)를 0.06 감소시켜, GWN이 이전 모델 대비 달성했던 성능 향상과 거의 동등한 수준을 기록한다. 이러한 개선 효과는 PEMS-BAY 데이터셋에도 유사한 상대적 정도로 일반화된다. 또한, 단기 및 장기 예측을 위한 별도의 모델을 앙상블하는 방식이 성능 향상에 더 큰 기여를 함을 보여준다. 코드는 https://github.com/sshleifer/Graph-WaveNet 에서 공개되어 있다.