Trustworthy Spatiotemporal Prediction Models
by
Chuizheng Meng
A Dissertation Presented to the
FACULTY OF THE USC GRADUATE SCHOOL
UNIVERSITY OF SOUTHERN CALIFORNIA
In Partial Fulfillment of the
Requirements for the Degree
DOCTOR OF PHILOSOPHY
(Computer Science)
May 2024
Copyright 2024 Chuizheng Meng
Dedication
This work is dedicated to my advisor, Prof. Yan Liu, my defense committee members, Prof. Willie Neiswanger
and Prof. Assad A Oberai, my labmates and collaborators, and my beloved family.
Acknowledgements
First and foremost, I would like to express my appreciation to my advisor, Prof. Yan Liu, who has not only
provided invaluable advice, support, and patience on research during my Ph.D. journey, but also has shared
priceless experience in many aspects of life that encouraged me to accomplish my Ph.D. study.
I would also like to thank Prof. Willie Neiswanger and Prof. Assad A Oberai for serving as the committee
members of my dissertation defense and providing beneficial suggestions for improvement.
I also appreciate all my labmates and collaborators, including but not limited to: Dr. Sungyong Seo,
Prof. Sirisha Rambhatla, Jianke Yang, Dr. Hao Niu, Dr. Guillaume Habault, Dr. Roberto Legaspi, Dr. Shinya
Wada, Dr. Chihiro Ono, and all members of the Melady Lab. Without their precious effort and input in the works presented in this dissertation, it would have been impossible for me to finish it.
Last and most importantly, I want to deeply thank my beloved family. During the most difficult time in
my Ph.D. study, it was their encouragement and support that helped me get through all the obstacles.
Table of Contents
Dedication . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . ii
Acknowledgements . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . iii
List of Tables . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . vii
List of Figures . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . ix
Abstract . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . xi
Chapter 1: Introduction . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 1
Chapter 2: Literature Review . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 7
2.1 Spatio-Temporal Data Forecasting and Physics Informed Machine Learning . . . . . . . . . 7
2.1.1 Forecasting From Spatiotemporal Data . . . . . . . . . . . . . . . . . . . . . . . . . 7
2.1.2 Long-Term Time Series Forecasting . . . . . . . . . . . . . . . . . . . . . . . . . . . 8
2.1.3 Physics-Informed Modeling of Spatiotemporal Data . . . . . . . . . . . . . . . . . . 8
2.2 Federated Learning and Graph Learning . . . . . . . . . . . . . . . . . . . . . . . . . . . . 9
Chapter 3: Physics-aware Difference Graph Networks for Sparsely-Observed Dynamics . . . . . . 11
3.1 Introduction . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 11
3.2 Physics-aware Difference Graph Network . . . . . . . . . . . . . . . . . . . . . . . . . . . . 14
3.2.1 Difference Operators on Graph . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 14
3.2.2 Difference Operators on Triangulated Mesh . . . . . . . . . . . . . . . . . . . . . . 15
3.2.3 Spatial Difference Layer . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 16
3.2.4 Recurrent Graph Networks . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 17
3.3 Effectiveness of Spatial Difference Layer . . . . . . . . . . . . . . . . . . . . . . . . . . . . 19
3.3.1 Approximation of Directional Derivatives . . . . . . . . . . . . . . . . . . . . . . . 19
3.3.2 Graph Signal Prediction . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 21
3.4 Prediction: Graph Signals on Land-based Weather Sensors . . . . . . . . . . . . . . . . . . 24
3.4.1 Experimental Set-up . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 24
3.4.2 Graph Signal Predictions . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 27
3.4.3 Contribution of Spatial Derivatives . . . . . . . . . . . . . . . . . . . . . . . . . . . 28
3.5 Evaluation on NEMO sea surface temperature (SST) dataset . . . . . . . . . . . . . . . . . 29
3.6 Ablation Study . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 30
3.6.1 Effect of different graph structures . . . . . . . . . . . . . . . . . . . . . . . . . . . 30
3.6.2 Evaluation on datasets with different sparsity . . . . . . . . . . . . . . . . . . . . . 31
3.7 The distribution of prediction error across nodes . . . . . . . . . . . . . . . . . . . . . . . . 32
3.8 Conclusion . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 33
Chapter 4: Physics-Informed Long-Sequence Forecasting From Multi-Resolution Spatiotemporal Data 34
4.1 Introduction . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 34
4.2 Problem Formulation . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 36
4.3 Spatiotemporal Koopman Multi-Resolution Network . . . . . . . . . . . . . . . . . . . . . 38
4.3.1 Physics-Informed Modeling of Intra-Resolution Dynamics . . . . . . . . . . . . . . 39
4.3.1.1 Neural Network Based Module (ST-Encoder and ST-Decoder) . . . . . . 39
4.3.1.2 Koopman Module . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 39
4.3.1.3 Gated Fusion of ST-Encoder and Koopman Module . . . . . . . . . . . . 40
4.3.2 Inter-Resolution Dynamics Modeling . . . . . . . . . . . . . . . . . . . . . . . . . . 41
4.3.2.1 Self-Attention Module . . . . . . . . . . . . . . . . . . . . . . . . . . . . 41
4.3.2.2 Upsampling and Downsampling Modules . . . . . . . . . . . . . . . . . . 41
4.3.3 Loss Function . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 43
4.4 Experiments . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 43
4.4.1 Long-Sequence Forecasting With Fully Observed Input . . . . . . . . . . . . . . . . 47
4.4.2 Long-Sequence Forecasting Results With Partially Observed Input . . . . . . . . . 49
4.4.3 Ablation Study . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 51
4.4.4 Effect of the number of resolutions on the performance. . . . . . . . . . . . . . . . 52
4.4.5 Interpretability of Koopman Module: Revealing Dynamics in Each Resolution . . . 53
4.5 Conclusion . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 54
Chapter 5: Cross-Node Federated Graph Neural Network for Spatio-Temporal Data Modeling . . . 55
5.1 Introduction . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 55
5.2 Cross-Node Federated Graph Neural Network . . . . . . . . . . . . . . . . . . . . . . . . . 57
5.2.1 Problem Formulation . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 57
5.2.2 Proposed Method . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 58
5.2.2.1 Modeling of Node-Level Temporal Dynamics . . . . . . . . . . . . . . . . 60
5.2.2.2 Modeling of Spatial Dynamics . . . . . . . . . . . . . . . . . . . . . . . . 62
5.2.2.3 Alternating Training of Node-Level and Spatial Models . . . . . . . . . . 62
5.3 Experiments . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 63
5.3.1 Spatio-Temporal Data Modeling: Traffic Flow Forecasting . . . . . . . . . . . . . . 64
5.3.2 Inductive Learning on Unseen Nodes . . . . . . . . . . . . . . . . . . . . . . . . . . 71
5.3.3 Ablation Study: Effect of Alternating Training and FedAvg on Node-Level and
Spatial Models . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 73
5.3.4 Ablation Study: Effect of Client Rounds and Server Rounds . . . . . . . . . . . . . 74
5.4 Conclusion . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 75
Chapter 6: Sample-Level Prototypical Federated Learning . . . . . . . . . . . . . . . . . . . . . . . 76
6.1 Introduction . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 76
6.2 Proposed Method: Sample-Level Prototypical Federated Learning (SL-PFL) . . . . . . . . . 78
6.2.1 Problem Formulation . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 78
6.2.2 Sample-Level Factorization of Data Distribution . . . . . . . . . . . . . . . . . . . . 79
6.2.3 Methodology . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 80
6.3 Experiments . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 88
6.3.1 Settings . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 88
6.3.2 Hyperparameter Tuning . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 91
6.3.3 Prediction Results . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 93
6.3.4 Ablation Study: Effect of Loss Function Components . . . . . . . . . . . . . . . . . 95
6.3.5 Ablation Study: Form of the Posterior Domain Distribution . . . . . . . . . . . . . 95
6.3.6 Ablation Study: Effect of Numbers of Local and Global Domains in SL-PFL . . . . . 96
6.3.7 Evaluation of Potential Privacy Leaks From Centroids . . . . . . . . . . . . . . . . 97
6.3.8 Convergence Behavior of SL-PFL . . . . . . . . . . . . . . . . . . . . . . . . . . . . 98
6.3.9 Computation And Communication Costs of SL-PFL . . . . . . . . . . . . . . . . . . 98
6.4 Conclusion . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 99
Chapter 7: Conclusions . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 100
Bibliography . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 103
List of Tables
1.1 Summary of contributions of presented works in my dissertation. . . . . . . . . . . . . . . 2
3.1 Mean squared error ($10^{-2}$) for approximation of directional derivatives. . . . . . . . . . . . 20
3.2 Mean absolute error ($10^{-2}$) for graph signal prediction. . . . . . . . . . . . . . . . . . . . . 23
3.3 Numbers of learnable parameters. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 26
3.4 Graph signal prediction results (MAE) on multistep predictions. . . . . . . . . . . . . . . . 28
3.5 Mean absolute error ($10^{-2}$) for SST graph signal prediction. . . . . . . . . . . . . . . . . . 30
3.6 Mean absolute error ($10^{-2}$) for graph signal prediction on the synthetic dataset. . . . . . . 30
3.7 Graph signal prediction results (MAE) on multistep predictions with different graph
structures. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 31
3.8 Mean absolute error ($10^{-2}$) for graph signal prediction with different sparsity. . . . . . . . 32
3.9 Mean squared error ($10^{-2}$) for approximations of directional derivatives of the function
f2(x, y) = sin(x) + cos(y) with different sparsity. . . . . . . . . . . . . . . . . . . . . . . . . 32
4.1 Datasets and tasks. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 46
4.2 Details of data splits. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 46
4.3 Forecasting results within the maximum possible horizon from fully observed input
sequences. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 47
4.4 Forecasting results within multiple horizons from fully observed input sequences. . . . . . 48
4.5 Forecasting results with partially observed input (YellowCab, Horizon=10d). . . . . . . . . 50
4.6 Ablation study results on YellowCab (Obs Ratio = 0.8) with the relative change of errors
after removing each component. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 51
4.7 Effect of the number of resolutions on prediction performance. . . . . . . . . . . . . . . . . 52
5.1 Table of notations. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 58
5.2 Statistics of datasets PEMS-BAY and METR-LA. . . . . . . . . . . . . . . . . . . . . . . . . 64
5.3 Parameters used for calculating the communication cost of GRU + FMTL. . . . . . . . . . . 68
5.4 Parameters used for calculating the communication cost of CNFGNN (AT + FedAvg). . . . 69
5.5 Parameters used for calculating the communication cost of CNFGNN (SL). . . . . . . . . . 69
5.6 Parameters used for calculating the communication cost of CNFGNN (SL + FedAvg). . . . . 69
5.7 Parameters used for calculating the communication cost of CNFGNN (AT, w/o FedAvg). . . 70
5.8 Comparison of performance on the traffic flow forecasting task. . . . . . . . . . . . . . . . 70
5.9 Comparison of the computation cost on edge devices and the communication cost. . . . . 70
5.10 Inductive learning performance measured with root mean squared error (RMSE). . . . . 71
5.11 Comparison of test error (RMSE) and the communication cost during training of different
training strategies of CNFGNN. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 73
6.1 Statistics of datasets. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 91
6.2 Architecture of the 5-layer CNN used in MiniDomainNet. . . . . . . . . . . . . . . . . . . . 93
6.3 Average prediction errors across clients. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 94
6.4 Comparison of the effect of different loss function components across clients. . . . . . . . 95
6.5 Comparison of the effect of different posterior domain distribution forms on the average
prediction errors across clients. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 95
6.6 Average prediction performance (RMSE) across clients on Shifts Weather with varying
kclient and K. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 96
6.7 Accuracy of membership inference with centroids on Shifts Weather. . . . . . . . . . . . . 97
List of Figures
3.1 Examples of difference operators applied to a graph signal. Filters used for the processing
are (b) $\sum_j (f_i - f_j)$, (c) $\sum_j (1.1 f_i - f_j)$, (d) $f_j - 0.5 f_i$. . . . . . . . . . . . . . . . 15
3.2 Physics-aware Difference Graph Networks for graph signal prediction. Blue boxes have
learnable parameters and all parameters are trained through end-to-end learning. The
nodes/edges can be multidimensional. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 17
3.3 Directional derivative on graph . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 19
3.4 Gradients and graph structure of sampled points. Left: the synthetic function is
$f_1(x, y) = 0.1x^2 + 0.5y^2$. Right: the synthetic function is $f_2(x, y) = \sin(x) + \cos(y)$. . . . 20
3.5 Synthetic dynamics and graph structure of sampled points. . . . . . . . . . . . . . . . . . . 24
3.6 Weather stations in (left) western (right) southeastern states in the United States and k-NN
graph. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 25
3.7 MAE across the nodes. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 32
4.1 Example of multi-resolution spatiotemporal data. . . . . . . . . . . . . . . . . . . . . . . . 35
4.2 ST-KMRN’s architecture. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 38
4.3 Visualization of spatial resolutions. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 45
4.4 Eigenvalues of Koopman matrices in the learned hidden space of input sequences in
different resolutions. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 52
5.1 Overview of Cross-Node Federated Graph Neural Network. . . . . . . . . . . . . . . . . . . 59
5.2 The histograms of data on the first 100 nodes ranked by ID. . . . . . . . . . . . . . . . . . . 65
5.3 Visualization of subgraphs visible in training under different ratios. . . . . . . . . . . . . . 72
5.4 Validation loss during the training stage of different training strategies. . . . . . . . . . . . 73
5.5 Effect of client rounds and server rounds (Rc, Rs) on forecasting performance and
communication cost. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 75
6.1 Overview of SL-PFL. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 82
6.2 Data Partition Visualization of Shifts Weather and MiniDomainNet. . . . . . . . . . . . . . 90
6.3 Test Error Curves of Baselines and SL-PFL. . . . . . . . . . . . . . . . . . . . . . . . . . . . 97
Abstract
With the great success of data-driven machine learning methods, concerns with the trustworthiness of
machine learning models have been emerging in recent years. From the modeling perspective, the lack
of trustworthiness amplifies the effect of insufficient training data. Purely data-driven models without
constraints from domain knowledge tend to suffer from over-fitting and lose generalizability on unseen data. Meanwhile, concerns about data privacy further restrict the availability of data from more providers.
On the application side, the absence of trustworthiness hinders the application of data-driven methods
in domains such as spatiotemporal forecasting, which involves data from critical applications including
traffic, climate and energy. My dissertation constructs spatiotemporal prediction models with enhanced
trustworthiness from both the model and the data aspects. For model trustworthiness, the dissertation focuses on improving the generalizability of models via the integration of physics knowledge. For data trustworthiness, the dissertation proposes a spatiotemporal forecasting model in the federated learning context,
where data in a network of nodes is generated locally on each node and remains decentralized. Furthermore,
the dissertation amalgamates the trustworthiness from both aspects and combines the generalizability of knowledge-informed models with the privacy preservation of federated learning for spatiotemporal
modeling.
Chapter 1
Introduction
Data-driven modeling approaches represented by deep learning have not only achieved great success in
typical artificial intelligence applications such as computer vision and natural language processing, but also
shown their potential to enhance numerical methods in science and engineering domains including partial
differential equation solving [76, 77], weather forecasting [78], traffic flow prediction [54, 101, 11] and
turbulence modeling [97, 70, 117]. The rich capacity and expressivity of neural networks based on the
universal approximation theorem [30] and the abundance of data have consistently been the root of their success across this variety of domains.
In spite of the great success of deep learning, concerns about its trustworthiness have also been emerging alongside its achievements. As pointed out in [10], generalizability, interpretability, and privacy preservation compose the crucial requirements of a trustworthy machine learning model. While the study of trustworthy AI has accompanied the development of AI in almost every domain it involves, the trustworthiness of spatiotemporal prediction models is of particular interest to us. Besides the fact that many applications in science and engineering have spatiotemporal data as the main object of research and put high emphasis on the trustworthiness of models, the properties of spatiotemporal data also bring both challenges and
opportunities to the development of trustworthy machine learning models.
The dissertation divides the trustworthiness of spatiotemporal prediction models into two categories: vertical and horizontal trustworthiness. Vertical trustworthiness covers the generalizability and interpretability of the model. With the same amount of available training data, a model with better vertical trustworthiness extracts more accurate information about the underlying dynamics and suffers from a lower risk of over-fitting. In comparison, horizontal trustworthiness emphasizes the privacy-preserving property of models. A model with stronger horizontal trustworthiness can more easily be extended to more data providers and utilize a larger amount of data. In the following text, I will discuss the unique challenges and opportunities
when spatiotemporal data meets each of the two categories of trustworthiness, as well as the approaches in
the dissertation to address them. Table 1.1 summarizes the contributions of works in my dissertation.
Table 1.1: Summary of contributions of presented works in my dissertation.
Trustworthiness | Focuses | Content | Contributions
Vertical | Generalizability, interpretability. | Chapter 3 | Incorporating the knowledge of numerical approximation of differential operators in a human-understandable explicit form for a more generalizable and interpretable spatiotemporal prediction model.
Vertical | Generalizability, interpretability. | Chapter 4 | Incorporating the Koopman theory for an interpretable and generalizable spatiotemporal modeling of long and multi-resolution sequences.
Horizontal | Privacy preservation through distributed data. | Chapter 5 | Modeling the complex spatiotemporal dependencies among clients while maintaining the data decentralized.
Comprehensive | Combining the trustworthiness from both aspects. | Chapter 6 | Combining the generalizability of knowledge-informed models with the privacy preservation of federated learning.
Vertical trustworthiness Spatiotemporal datasets are commonly collected as data points in discrete
spatial and temporal steps of real-world physics phenomena or simulated physics processes, which are
further constrained by physics laws. While the physics constraints raise stricter requirements on data-driven models compared to general approximation problems, the dominating physics laws provide extra
knowledge about the underlying dynamics and have the potential to reduce the search space of optimization.
Meanwhile, as a human-understandable form of knowledge, physics knowledge provides a new path
leading to interpretable machine learning models. On one hand, computation graphs and training objectives
designed in alignment with physics laws produce physically meaningful intermediate variables, outputs
and loss functions instead of leaving them in the abstract hidden space. On the other hand, analyzing such
variables after data-driven optimization, based on their roles in physics laws, in turn helps reveal properties of the input data and the focus of trained models.
Partial differential equations (PDEs) have long been the primary mathematical tool to describe the
physics laws behind various spatiotemporal processes, and the dissertation incorporates the knowledge of
numerical PDE solvers into data-driven models as the first step to improve the vertical trustworthiness.
As approximations of derivatives in the continuous domain, difference operators play a core role in computing numerical solutions of PDEs (e.g., Navier–Stokes equations). Therefore, the dissertation
presents physics-aware difference graph networks (PA-DGN) [88] whose architecture is motivated to
leverage differences of sparsely available data from the physical systems. The differences are particularly
important since most of the physics-related dynamic equations handle differences of physical quantities
in spatial and temporal space instead of using the quantities directly. Inspired by the property, PA-DGN
first constructs spatial difference layer (SDL) to efficiently learn the local representations by aggregating
neighboring information in the sparse data points. SDL closely aligns with the computation process of
numerical difference operators, but integrates a data-driven neural network layer to infer the coefficients in
numerical methods. The layer is based on graph networks (GN) as it easily leverages structural features to
learn the localized representations and share the parameters for computing the localized features. As a
result, SDL combines the advantage of expressivity from data-driven models and the interpretability of
a numerical method based computation graph. SDL is followed by recurrent graph networks (RGN) to
predict the temporal difference which is another core component of physics-related dynamic equations.
PA-DGN is applicable to various tasks and the work provides two representative tasks: the approximation
of directional derivatives and the prediction of graph signals.
Besides PDEs, the linear approximation of nonlinear dynamic systems has also been central to the modeling of spatiotemporal processes. Not only does it provide a practical modeling approach for long sequences from complex nonlinear systems, but it also reveals key properties of the dynamics, including the magnitudes and frequencies of its components. In comparison, data-driven models have
been struggling with the modeling of long sequences due to the computation and memory costs brought
by the long dependency paths across time steps. Meanwhile, they lack the interpretability of the linear
approximation approach. Inspired by the advantages of the linear approximation approach, the dissertation
presents Spatiotemporal Koopman Multi-Resolution Network (ST-KMRN) [68] to address the challenge of
long sequence modeling. ST-KMRN combines the Koopman theory-based modeling of dynamic systems
with deep learning-based prediction modules for more principled modeling of dynamics. The inferred
Koopman matrix offers interpretations of heterogeneous dynamics in different resolutions at the same time.
Horizontal trustworthiness The vast amount of data generated from networks of sensors, wearables, and Internet of Things (IoT) devices underscores the need for advanced modeling techniques that leverage the spatiotemporal structure of decentralized data, driven by edge-computation requirements and licensing
(data access) issues. While federated learning (FL) has emerged as a framework for model training without
requiring direct data sharing and exchange, effectively modeling the complex spatiotemporal dependencies
to improve forecasting capabilities still remains an open problem. On the other hand, state-of-the-art
spatiotemporal forecasting models assume unfettered access to the data, neglecting constraints on data
sharing.
To bridge this gap, the dissertation presents a federated spatiotemporal model – Cross-Node Federated
Graph Neural Network (CNFGNN) [69] – which explicitly encodes the underlying graph structure using
graph neural network (GNN)-based architecture under the constraint of cross-node federated learning,
which requires that data in a network of nodes is generated locally on each node and remains decentralized.
CNFGNN operates by disentangling the temporal dynamics modeling on devices and spatial dynamics on
the server, utilizing alternating optimization to reduce the communication cost, facilitating computations
on the edge devices. Experiments on the traffic flow forecasting task show that CNFGNN achieves the best
forecasting performance in both transductive and inductive learning settings with no extra computation
cost on edge devices, while incurring modest communication cost.
Comprehensive trustworthiness Compared to the results of spatiotemporal prediction models with
vertical trustworthiness, which gain both generalizability and interpretability from prior knowledge,
prediction models satisfying horizontal trustworthiness still have to sacrifice prediction performance as the cost of fulfilling the requirement of distributed data. Fortunately, and not unexpectedly, the combination
of both vertical and horizontal trustworthiness may have a strong potential to shrink the gap. The
improvement from CNFGNN already illustrates that correct inductive bias such as relations among clients
will help mitigate the data heterogeneity and improve the performance under the constraint of horizontal
trustworthiness, implying that prior knowledge in models and constraints may also be beneficial.
As a comprehensive solution to the trustworthiness of spatiotemporal prediction models, the dissertation
presents a spatiotemporal prediction model under the federated learning context, which combines the
generalizability of knowledge-informed models with the privacy preservation of federated learning. In
federated learning, as a result of data locality, data is usually not independently and identically distributed (non-IID) across clients, and the non-IID property has long been the key challenge in FL. Furthermore,
in real-world cross-silo scenarios, it is ubiquitous that clients are organizations owning private data from
multiple domains internally, which exacerbates the non-IID issue. For example, in healthcare applications,
each client (hospital) gathers data from patients with heterogeneous demographics. While previous works
have made efforts to address the non-IID challenge across clients by assuming various relations among
client-level data distributions and enabling personalized models at the client level, they ignore the internal
data heterogeneity within each client or require explicit data domain indicators, which are hardly accessible
in real-world data. In this dissertation, I propose Sample-Level Prototypical Federated Learning (SL-PFL) to
bridge the gap. SL-PFL incorporates prototypical learning under the FL framework and provides a fine-grained personalized model for each data sample instead of learning one uniform model for all samples of each client. Meanwhile, it can be trained using data without ground-truth domain indicators. Experimental
results demonstrate that SL-PFL outperforms existing FL methods with a global model or client-level
personalized models on various real-world regression and classification tasks from weather, computer
vision, and healthcare applications.
Chapter 2
Literature Review
2.1 Spatio-Temporal Data Forecasting and Physics Informed Machine
Learning
The dissertation involves elements from forecasting from spatiotemporal data, long-term time series
forecasting, and physics-informed modeling of spatiotemporal data. In the following subsections, we
discuss related works in each domain.
2.1.1 Forecasting From Spatiotemporal Data
In recent years, methods based on Graph Neural Networks (GNNs) have shown superior performance
in forecasting tasks from spatiotemporal data in various domains including physical simulation [40, 32],
traffic [54, 102, 101, 108, 112, 11], human motion [104, 40, 32], and climate [87, 33]. GNN-based methods
typically utilize recurrent neural networks or one-dimensional convolution neural networks to capture
temporal dependencies, and message passing or graph convolutions to model spatial dependencies. While
existing works achieve state-of-the-art performance on spatiotemporal data with a single resolution, they
lack the modeling of multi-resolution data. Although data in multiple resolutions can be processed as extra
input features, it is not optimal for fully exploiting the rich contextual information among resolutions.
2.1.2 Long-Term Time Series Forecasting
One accompanying problem with multi-resolution forecasting is long-term time series forecasting, as multi-step predictions for coarse temporal resolutions involve long prediction horizons into the future. [113]
will demonstrate in this work that forecasting tasks on multiple resolutions can benefit long-term prediction
performance.
2.1.3 Physics-Informed Modeling of Spatiotemporal Data
As the underlying processes of spatiotemporal data are usually governed by physics laws, physics-informed
methods have the potential to further improve the performance of neural network models by incorporating
inductive bias.
Domain-Specific Knowledge Domain-specific knowledge provides effective inductive bias for solving
problems within one specific domain. For example, [98, 38] incorporate domain knowledge as regularizations in deep learning models to improve the performance of turbulence simulation and lake temperature
prediction respectively. However, when we need a model addressing spatiotemporal modeling tasks from
multiple domains, general knowledge of dynamic systems needs to be integrated.
Numerical Methods of Solving Ordinary/Partial Differential Equations (ODEs/PDEs) As the
physics laws governing various spatiotemporal processes can be described with ODEs and PDEs in similar
forms, incorporating numerical methods generally applicable to ODEs and PDEs can benefit tasks in
multiple domains. [13, 79] propose scalable back propagation methods through numerical ODE solvers and
enable the modeling of irregular time series. [25, 34, 87] perform convolutions constrained by numerical
PDE solvers and achieve better performance compared to unconstrained convolutions.
Koopman Theory Based Methods Koopman theory is based on the insight that the state space of
a non-linear dynamic system can be encoded into an infinite-dimensional space where the dynamics is linear [43]. In practice, people assume the infinite-dimensional space can be approximated with a finite-dimensional space. The key problem is then to find a proper encoder/decoder pair to map between the state space and the hidden space.
Traditionally, people construct the encoder/decoder with hand-crafted functions, such as the identity
function in Dynamic Mode Decomposition (DMD) [85], nonlinear functions in Extended DMD (EDMD) [99],
and kernel functions in Kernel DMD (KDMD) [39]. However, hand-crafted functions may fail to fit complex
dynamic systems and are hard to design without domain-specific knowledge. Thus, recent works [56,
3, 62] construct encoders/decoders using neural networks as trainable universal approximators. They
demonstrate that the combination of neural networks and Koopman theory achieves comparable or even
higher performance than the Koopman approximators with hand-crafted mapping functions, while enjoying
the ability to generalize to multiple datasets with the same design. [56] further shows that the integration
of Koopman theory allows the model to adapt to new systems with unknown dynamics faster than pure
neural networks.
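To make the encoder/decoder view concrete, the minimal NumPy sketch below implements the simplest member of this family, Dynamic Mode Decomposition with an identity encoder: a linear operator K is fitted so that each snapshot maps approximately to the next, and its eigenvalues expose decay rates and frequencies. This is only an illustrative example with our own naming and a toy trajectory; it is not the neural Koopman module used later in the dissertation.

```python
import numpy as np

def dmd_fit(X, r=None):
    """Fit a linear operator K such that X[:, 1:] ≈ K @ X[:, :-1].

    X: (state_dim, T) snapshot matrix of one trajectory.
    r: optional truncation rank for the SVD.
    """
    X1, X2 = X[:, :-1], X[:, 1:]                  # consecutive snapshot pairs
    U, S, Vh = np.linalg.svd(X1, full_matrices=False)
    if r is not None:                             # keep only the leading r modes
        U, S, Vh = U[:, :r], S[:r], Vh[:r]
    # Least-squares solution K = X2 @ pinv(X1), computed through the SVD.
    return X2 @ Vh.conj().T @ np.diag(1.0 / S) @ U.conj().T

# Toy usage: a damped oscillation, which is linear in this 2-D state.
t = np.linspace(0.0, 10.0, 200)
X = np.stack([np.exp(-0.05 * t) * np.cos(t), np.exp(-0.05 * t) * np.sin(t)])
K = dmd_fit(X)
eigvals = np.linalg.eigvals(K)
print(np.abs(eigvals), np.angle(eigvals))         # decay rates and frequencies
```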
2.2 Federated Learning and Graph Learning
The dissertation derives elements from federated learning, graph neural networks, and privacy-preserving graph learning; we now discuss related works in these areas.
Federated Learning (FL). Federated learning is a machine learning setting where multiple clients train
a model in collaboration with decentralized training data [35]. It requires that the raw data of each client is
stored locally without any exchange or transfer. However, the decentralized training data comes at the
cost of less utilization due to the heterogeneous distributions of data on clients and the lack of information
exchange among clients. Various optimization algorithms have been developed for federated learning on
non-IID and unbalanced data [66, 51, 37]. [92] propose a multi-task learning framework that captures
relationships amongst data. While the above works mitigate the issue of missing neighbors' information
to some extent, they are not as effective as GNN models and still suffer from the absence of feature exchange
and aggregation.
Graph Neural Networks (GNNs). GNNs have shown their superior performance on various learning
tasks with graph-structured data, including graph embedding [26], node classification [42], spatio-temporal
data modeling [104, 54, 108] and multi-agent trajectory prediction [5, 40, 49]. Recent GNN models [26,
106, 107, 31] also have sampling strategies and are able to scale on large graphs. While GNNs enjoy the
benefit from strong inductive bias [6, 103], most works require centralized data during the training and the
inference processes.
Privacy-Preserving Graph Learning. [93] and [67] use statistics of graph structures instead of node
information exchange and aggregation to avoid the leakage of node information. Recent works have also
incorporated graph learning models with privacy-preserving techniques such as Differential Privacy (DP),
Secure Multi-Party Computation (MPC) and Homomorphic Encryption (HE). [114] utilize MPC and HE
when learning a GNN model for node classification with vertically split data to preserve silo-level privacy
instead of node-level privacy. [81] preprocesses the input raw data with DP before feeding it into a GNN
model. Composing privacy-preserving techniques for graph learning can help build federated learning
systems following the privacy-in-depth principle, wherein the privacy properties degrade as gracefully as
possible if one technique fails [35].
Chapter 3
Physics-aware Difference Graph Networks for Sparsely-Observed
Dynamics
3.1 Introduction
Modeling real-world phenomena, such as climate observations, traffic flow, and physics and chemistry simulation [55, 22, 60, 9, 82, 24], is important but extremely challenging. While deep learning has achieved
remarkable successes in prediction tasks by learning latent representations from data-rich applications
such as image recognition [44], text understanding [100], and speech recognition [29], we confront many
challenging scenarios in modeling natural phenomena with deep neural networks when only a limited
number of observations are available. Particularly, the sparsely available data points cause substantial
numerical error when we utilize existing finite difference operators and the limitation requires a more
principled way to redesign deep learning models.
While many methods have been proposed to model physics-simulated observations using deep learning,
many of them are designed under the assumption that input is on a continuous domain. For example, [76,
77] proposed physics-informed neural networks (PINNs) to learn nonlinear relations between input (spatial and temporal coordinates $(x, t)$) and output simulated with a given partial differential equation (PDE).
11
Since [76, 77] use the coordinates as input and compute derivatives based on the coordinates to represent
the equation, the setting is only valid when the data are densely observed over spatial and temporal space.
Prior knowledge related to physics equations has been combined with data-driven models for various
purposes. [15] proposed a nonlinear diffusion process for image restoration and [9] incorporated the
transport physics (advection-diffusion equation) with deep neural networks for forecasting sea surface
temperature by extracting the motion field. [63] introduced deep Lagrangian networks specialized to learn
Lagrangian mechanics with learnable parameters. [86] proposed a physics-informed regularizer to impose
data-specific physics equations. A common limitation is that the methods in [15, 9, 63] are not efficiently applicable to
sparsely discretized input as only a small number of data points are available and continuous properties on
given space are not easily recovered. It is unsuitable to directly use continuous differential operators to
provide local behaviors because it is hard to approximate the continuous derivatives precisely with the
sparse points [89, 1, 61]. Furthermore, they are only applicable when the specific physics equations are
explicitly given, and they are hard to generalize to incorporate other types of equations.
As another direction for modeling physics-simulated data, [60] proposed PDE-Net, which uncovers the
underlying hidden PDEs and predicts the dynamics of complex systems. [80] derived new CNNs: parabolic
and hyperbolic CNNs, based on the ResNet [28] architecture and motivated by PDE theory. While [60, 80] are flexible enough to uncover hidden physics from the constrained kernels, they are still restricted to a regular grid where the proposed constraints on the learnable filters are easily defined.
The topic of reasoning about the physical dynamics of discrete objects has been actively studied [82, 5, 12] with the emergence of graph-based neural networks [42, 83, 24]. Although these models can handle sparsely located data points without explicitly given physics equations, they are purely data-driven, so the physics-inspired inductive bias for exploiting finite differences is not considered at all. In contrast, our method consists of physics-aware modules that efficiently leverage this inductive bias to learn spatiotemporal
data from the physics system.
In this paper, we propose physics-aware difference graph networks (PA-DGN) whose architecture is
motivated to leverage differences of sparsely available data from the physical systems. The differences are
particularly important since most of the physics-related dynamic equations (e.g., Navier–Stokes equations)
handle differences of physical quantities in spatial and temporal space instead of using the quantities
directly. Inspired by the property, we first propose spatial difference layer (SDL) to efficiently learn the
local representations by aggregating neighboring information in the sparse data points. The layer is based
on graph networks (GN) as it easily leverages structural features to learn the localized representations and
share the parameters for computing the localized features. Then, the layer is followed by recurrent graph
networks (RGN) to predict the temporal difference which is another core component of physics-related
dynamic equations. PA-DGN is applicable to various tasks, and we provide two representative tasks: the
approximation of directional derivatives and the prediction of graph signals.
Our contributions are:
• We tackle a limitation of sparsely discretized data, which causes numerical error when we model the physical system, by proposing the spatial difference layer (SDL) for efficiently exploiting neighboring information under the limitation of sparsely observable points.
• We combine SDL with recurrent graph networks to build PA-DGN which automatically learns the
underlying spatiotemporal dynamics in graph signals.
• We verify that PA-DGN is effective in approximating directional derivatives and predicting graph
signals in synthetic data. Then, we conduct exhaustive experiments to predict climate observations
from land-based weather stations and demonstrate that PA-DGN outperforms other baselines.
3.2 Physics-aware Difference Graph Network
In this section, we introduce the building module used to learn spatial differences of graph signals and
describe how the module is used to predict signals in the physics system.
3.2.1 Difference Operators on Graph
As approximations of derivatives in the continuous domain, difference operators play a core role in computing numerical solutions of (continuous) differential equations. Since it is hard to derive closed-form
expressions of derivatives in real-world data, the difference operators have been considered as alternative
tools to describe and solve PDEs in practice. The operators are especially important for physics-related
data (e.g., meteorological observations) because the governing rules behind the observations are mostly
differential equations.
Graph signals Given a graph $G = (V, E)$, where $V = \{1, \dots, N_v\}$ is a set of vertices and $E \subseteq \{(i, j) \mid i, j \in V\}$ ($|E| = N_e$) is a set of edges, the graph signals on all nodes at time $t$ are $f(t) = \{f_i(t) \mid i \in V\}$ where $f_i : V \to \mathbb{R}$. Graph signals on edges can be defined similarly, $F(t) = \{F_{ij}(t) \mid (i, j) \in E\}$ where $F_{ij} : E \to \mathbb{R}$. Both signals can be multidimensional.
Gradient on graph The gradient ($\nabla$) of a function on the nodes of a graph is represented by finite differences:
$$\nabla : L^2(V) \to L^2(E), \qquad (\nabla f)_{ij} = (f_j - f_i) \ \text{if } (i, j) \in E \text{ and } 0 \text{ otherwise},$$
where $L^2(V)$ and $L^2(E)$ denote the Hilbert spaces of node and edge functions, respectively. The gradients on a graph provide finite differences of graph signals and become edge $(i, j)$ features.
(a) Original graph signals (b) Detected edge (c) Sharpened signals (d) Modulated gradients
Figure 3.1: Examples of difference operators applied to a graph signal. Filters used for the processing are (b) $\sum_j (f_i - f_j)$, (c) $\sum_j (1.1 f_i - f_j)$, (d) $f_j - 0.5 f_i$.
Laplace-Beltrami operator The Laplace-Beltrami operator (or Laplacian, $\Delta$) in the graph domain is defined as
$$\Delta : L^2(V) \to L^2(V), \qquad (\Delta f)_i = \sum_{j:(i,j)\in E} (f_i - f_j) \quad \forall i \in V.$$
This operator is usually written in matrix form in other literature, $L = D - A$, where $A$ is an adjacency matrix and $D = \mathrm{diag}(\sum_{j:j\neq i} A_{ij})$ is the degree matrix.
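For concreteness, the short NumPy sketch below evaluates the two standard operators defined above on a toy graph; the function names and the example edge list are illustrative choices of ours, not code from the dissertation.

```python
import numpy as np

def graph_gradient(f, edges):
    """Standard graph gradient: (∇f)_ij = f_j - f_i for each edge (i, j)."""
    return np.array([f[j] - f[i] for i, j in edges])

def graph_laplacian(f, edges, num_nodes):
    """Graph Laplacian as defined above: (Δf)_i = Σ_{j:(i,j)∈E} (f_i - f_j)."""
    lap = np.zeros(num_nodes)
    for i, j in edges:
        lap[i] += f[i] - f[j]
    return lap

def laplacian_matrix(edges, num_nodes):
    """Equivalent matrix form L = D - A for the same (undirected) edge set."""
    A = np.zeros((num_nodes, num_nodes))
    for i, j in edges:
        A[i, j] = 1.0
    D = np.diag(A.sum(axis=1))
    return D - A

edges = [(0, 1), (1, 0), (1, 2), (2, 1)]   # a 3-node path graph, both directions
f = np.array([0.0, 1.0, 4.0])
print(graph_gradient(f, edges))            # finite differences on edges
print(graph_laplacian(f, edges, 3))        # node-wise Laplacian
print(laplacian_matrix(edges, 3) @ f)      # matches (D - A) f
```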
3.2.2 Difference Operators on Triangulated Mesh
According to [19], the gradient and Laplacian operators on the triangulated mesh can be discretized by
incorporating the coordinates of nodes. To obtain the gradient operator, the per-face gradient of each
triangular face is calculated first. Then, the gradient on each node is the area-weighted average of all
its neighboring faces, and the gradient on edge (i, j) is defined as the dot product between the per-node
gradient and the direction vector eij . The Laplacian operator can be discretized with Finite Element Method
(FEM):
$$(\Delta f)_i = \frac{1}{2} \sum_{j:(i,j)\in E} (\cot \alpha_j + \cot \beta_j)(f_j - f_i),$$
where node $j$ belongs to node $i$'s immediate neighbors ($j \in N_i$) and $(\alpha_j, \beta_j)$ are the two angles opposing the edge $(i, j)$.
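A minimal sketch of this cotangent-weighted Laplacian is given below, assuming the mesh is supplied as vertex coordinates and triangle index triples; the numerical guard on the sine term and the toy mesh are our own illustrative choices rather than the discretization code used in the experiments.

```python
import numpy as np

def cotangent_laplacian(f, vertices, triangles):
    """FEM-style Laplacian: (Δf)_i = 1/2 Σ_j (cot α_j + cot β_j)(f_j - f_i),
    where α_j, β_j are the two angles opposite edge (i, j) in its adjacent triangles."""
    lap = np.zeros(len(vertices))
    for tri in triangles:
        for k in range(3):
            i, j, opp = tri[k], tri[(k + 1) % 3], tri[(k + 2) % 3]
            u = vertices[i] - vertices[opp]
            v = vertices[j] - vertices[opp]
            cos_a = np.dot(u, v)
            # |u||v| sin(angle) via the identity |u x v|^2 = |u|^2 |v|^2 - (u . v)^2
            sin_a = np.sqrt(max(np.dot(u, u) * np.dot(v, v) - cos_a ** 2, 1e-12))
            # each adjacent triangle contributes one of the two opposing cotangents
            lap[i] += 0.5 * (cos_a / sin_a) * (f[j] - f[i])
            lap[j] += 0.5 * (cos_a / sin_a) * (f[i] - f[j])
    return lap

# Toy usage: a unit square split into two triangles.
verts = np.array([[0.0, 0.0], [1.0, 0.0], [1.0, 1.0], [0.0, 1.0]])
tris = [(0, 1, 2), (0, 2, 3)]
print(cotangent_laplacian(np.array([0.0, 1.0, 2.0, 1.0]), verts, tris))
```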
3.2.3 Spatial Difference Layer
While the difference operators generalize to Riemannian manifolds [46, 58], there exists numerical error compared to the operators in continuous space, and the error can be larger when the nodes are spatially far from their neighboring nodes, because the connected nodes ($j \in N_i$) of the $i$-th node fail to represent the local features around the node. Furthermore, the error is even larger if the available data points are sparsely distributed (e.g., sensor-based observations). In other words, the difference operators are unlikely to discover meaningful spatial variations behind the sparse observations since they are highly limited to immediate neighboring information only. To mitigate this limitation, we propose the spatial difference layer (SDL), which consists of a set of parameters defining learnable difference operators in the form of a gradient and a Laplacian to fully utilize neighboring information:
$$({}^{w}\nabla f)_{ij} = w^{(g_1)}_{ij}\,(f_j - w^{(g_2)}_{ij} f_i), \qquad ({}^{w}\Delta f)_i = \sum_{j:(i,j)\in E} w^{(l_1)}_{ij}\,(f_i - w^{(l_2)}_{ij} f_j) \qquad (3.1)$$
where $w_{ij}$ are the parameters tuning the difference operators along the corresponding edge direction $e_{ij}$. The two forms (Eq 3.1) are associated with edge and node features, respectively. The superscript in ${}^{w}\nabla$ and ${}^{w}\Delta$ denotes that the difference operators are functions of the learnable parameters $w$. $w^{(g)}_{ij}$ and $w^{(l)}_{ij}$ are obtained by integrating local information as follows:
$$w_{ij} = g(\{f_k, F_{mn} \mid k, (m, n) \in h\text{-hop neighborhood of edge } (i, j)\}) \qquad (3.2)$$
While the standard difference operators consider two connected nodes only (i and j) for each edge (i, j),
Eq 3.2 uses a larger view ($h$-hop) to represent the differences between nodes $i$ and $j$. Since graph networks (GN) [6] efficiently aggregate neighboring information, we use a GN for the function $g(\cdot)$, and $w_{ij}$ are the edge features output by the GN. Eq 3.2 can be viewed as a higher-order difference equation because nodes/edges that are multiple hops apart are considered.
Figure 3.2: Physics-aware Difference Graph Networks for graph signal prediction. Blue boxes have learnable parameters, and all parameters are trained through end-to-end learning. The nodes/edges can be multidimensional.
$w_{ij}$ plays a role similar to the parameters in the convolution kernels of CNNs. For example, while the standard gradient operator can be regarded as a simple edge-detecting filter, the operator becomes a sharpening filter if $w^{(g_1)}_{ij} = 1$ and $w^{(g_2)}_{ij} = \frac{|N_i|+1}{|N_i|}$ for node $i$ and the operators over each edge are summed. In other words, by modulating $w_{ij}$, the layer readily extends to conventional kernels, including edge-detection or sharpening filters, and even to more complicated kernels. On top of $w_{ij}$, the difference forms in Eq 3.1 intentionally make the optimization of the learnable parameters depend on differences rather than on raw values. Eq 3.1 thus naturally provides the physics-inspired inductive bias which is particularly effective for modeling physics-related observations. Furthermore, it is straightforward to increase the number of channels of $w^{(g)}_{ij}$ and $w^{(l)}_{ij}$ to make the layer more expressive. Figure 3.1 illustrates how the exemplary filters convolve the given graph signals.
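The following PyTorch sketch illustrates Eq 3.1 and Eq 3.2 in simplified form: a small edge MLP stands in for the graph network g(·) and produces four modulation weights per edge, which are then plugged into the modulated gradient and Laplacian. This is a minimal illustration of the idea, not the exact PA-DGN implementation; the class name, hidden sizes, and the 1-hop features fed to the edge network are our assumptions.

```python
import torch
import torch.nn as nn

class SpatialDifferenceLayer(nn.Module):
    """Simplified sketch of Eq 3.1/3.2: learnable modulated gradient and Laplacian.

    The edge MLP below is a stand-in for the graph network g(·) that produces
    the modulation weights w_ij from local features."""
    def __init__(self, node_dim, hidden_dim=32):
        super().__init__()
        # Four modulation weights per edge: w^(g1), w^(g2), w^(l1), w^(l2).
        self.edge_net = nn.Sequential(
            nn.Linear(2 * node_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 4),
        )

    def forward(self, f, edge_index):
        # f: (num_nodes, node_dim); edge_index: (2, num_edges) with rows (i, j).
        src, dst = edge_index
        w = self.edge_net(torch.cat([f[src], f[dst]], dim=-1))
        wg1, wg2, wl1, wl2 = w.chunk(4, dim=-1)
        # Modulated gradient on edges: (w∇f)_ij = w^(g1)_ij (f_j - w^(g2)_ij f_i)
        grad = wg1 * (f[dst] - wg2 * f[src])
        # Modulated Laplacian on nodes: (wΔf)_i = Σ_j w^(l1)_ij (f_i - w^(l2)_ij f_j)
        lap_terms = wl1 * (f[src] - wl2 * f[dst])
        lap = torch.zeros_like(f).index_add_(0, src, lap_terms)
        return grad, lap

# Toy usage on a 3-node path graph (both edge directions).
f = torch.randn(3, 1)
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
grad, lap = SpatialDifferenceLayer(node_dim=1)(f, edge_index)
```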
3.2.4 Recurrent Graph Networks
Difference graph Once the modulated spatial differences $({}^{w}\nabla f(t), {}^{w}\Delta f(t))$ are obtained, they are concatenated with the current signals $f(t)$ to construct node-wise ($z_i$) and edge-wise ($z_{ij}$) features, and the resulting graph is called a difference graph. The difference graph includes all information needed to describe the spatial variations.
Recurrent graph networks Given a snapshot $(f(t), F(t))$ of a sequence of graph signals, one difference graph is obtained and used to predict the next graph signals. While a non-linear layer can be used to combine the learned spatial differences to predict the next signals, it is limited to discovering spatial relations only among the features in the difference graph. Since many equations describing physics-related phenomena are non-static (e.g., Navier–Stokes equations), we adopt recurrent graph networks (RGN) [82] with a graph state $G_h$ as input to combine the spatial differences with temporal variations. RGN returns a graph state $G^*_h = (h^{*(v)}, h^{*(e)})$ and the next graph signals $z^*_i$ and $z^*_{ij}$. The update rule is described as follows:
1. $(z^*_{ij}, h^{*(e)}) \leftarrow \phi^e(z_{ij}, z_i, z_j, h^{(e)})$ for all $(i, j) \in E$ pairs,
2. $(z^*_i, h^{*(v)}) \leftarrow \phi^v(z_i, \bar{z}'_i, h^{(v)})$ for all $i \in V$, where $\bar{z}'_i$ is an aggregated edge attribute related to node $i$,
where $\phi^e$ and $\phi^v$ are edge and node update functions, respectively, and they can be any recurrent unit. Finally, the prediction is made through a decoder by feeding it the graph signals $z^*_i$ and $z^*_{ij}$.
Learning objective Let $\hat{f}$ and $\hat{F}$ denote the predictions of the target node/edge signals. PA-DGN is trained by minimizing the following objective:
$$L = \sum_{i \in V} \|f_i - \hat{f}_i\|^2 + \sum_{(i,j) \in E} \|F_{ij} - \hat{F}_{ij}\|^2. \qquad (3.3)$$
For multistep predictions, $L$ is summed over all prediction steps. If only one type (node or edge) of signal is given, the corresponding term in Eq 3.3 is used to optimize the parameters in SDL and RGN simultaneously.
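As a simplified sketch of how the pieces fit together, the code below concatenates the modulated spatial differences with the current signals to form the node part of a difference graph, runs a per-node GRU cell in place of the full recurrent graph network, and trains with the node term of Eq 3.3. It assumes the SpatialDifferenceLayer sketched earlier and is illustrative only; the actual PA-DGN uses separate edge and node update functions with message passing.

```python
import torch
import torch.nn as nn

class PADGNSketch(nn.Module):
    """Simplified sketch of a PA-DGN prediction step and the Eq 3.3 objective.

    A per-node GRUCell over [signal, modulated Laplacian, aggregated modulated
    gradient] replaces the full recurrent graph network; `sdl` is assumed to be
    the SpatialDifferenceLayer sketched earlier."""
    def __init__(self, sdl, node_dim, hidden_dim=64):
        super().__init__()
        self.sdl = sdl
        self.cell = nn.GRUCell(3 * node_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, node_dim)

    def forward(self, f_seq, edge_index):
        # f_seq: (T, num_nodes, node_dim); returns one-step-ahead predictions.
        src, _ = edge_index
        h, preds = None, []
        for f in f_seq:
            grad, lap = self.sdl(f, edge_index)
            # Aggregate edge-wise gradients back onto their source nodes.
            grad_node = torch.zeros_like(f).index_add_(0, src, grad)
            z = torch.cat([f, lap, grad_node], dim=-1)   # node part of the difference graph
            h = self.cell(z, h)
            preds.append(f + self.decoder(h))            # predicted next signal
        return torch.stack(preds)

def node_loss(pred, target):
    # Eq 3.3 restricted to node signals, summed over prediction steps.
    return ((pred - target) ** 2).sum()
```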
Figure 3.3: Directional derivative on a graph.
3.3 Effectiveness of Spatial Difference Layer
To investigate if the proposed spatial difference forms (Eq 3.1) can be beneficial to learning physics-related
patterns, we use SDL on two different tasks: (1) approximate directional derivatives and (2) predict synthetic
graph signals.
3.3.1 Approximation of Directional Derivatives
As we claimed in Section 3.2.3, the standard difference forms (gradient and Laplacian) on a graph can easily cause significant numerical error because they are sensitive to the distance between two points and to the variations of a given function. To evaluate the applicability of the proposed SDL, we train SDL to approximate directional derivatives on a graph. First, we define a synthetic function and its gradients on 2D space and sample 200 points $(x_i, y_i)$. Then, we construct a graph on the sampled points using the k-NN algorithm ($k = 4$). With the known gradient $\nabla f = (\frac{\partial f}{\partial x}, \frac{\partial f}{\partial y})$ at each point (a node in the graph), we can compute directional derivatives by projecting $\nabla f$ onto a connected edge $e_{ij}$ (see Figure 3.3). We compare against four baselines:
(1) the finite gradient (FinGrad), (2) a multilayer perceptron (MLP), (3) graph networks (GN), and (4) a different form of Eq 3.1 (One-w). For the finite gradient ($(f_j - f_i)/\|x_j - x_i\|$), there is no learnable parameter and it only uses two points. For MLP, we feed $(f_i, f_j, x_i, x_j)$ as input to see whether learnable parameters can benefit the approximation or not. For GN, we use the distances between connected points as edge features and the function values on the points as node features. The edge feature output of the GN is used as the prediction for the directional derivative on the edge. Finally, we modify the proposed form as $({}^{w}\nabla f)_{ij} = w_{ij} f_j - f_i$. GN and the modified form are used to verify the effectiveness of Eq 3.1. Note that we define two synthetic functions (Figure 3.4) with different properties: (1) monotonically increasing from a center and (2) periodically varying.

Figure 3.4: Gradients and graph structure of sampled points. Left: the synthetic function is $f_1(x, y) = 0.1x^2 + 0.5y^2$. Right: the synthetic function is $f_2(x, y) = \sin(x) + \cos(y)$.
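The data generation for this task can be summarized with the short sketch below, which samples points, builds the k-NN graph, projects the analytic gradient onto each edge direction to obtain the target directional derivatives, and also evaluates the FinGrad baseline. The library choice (SciPy's cKDTree) and the random seed are our own; the dissertation does not specify them.

```python
import numpy as np
from scipy.spatial import cKDTree

rng = np.random.default_rng(0)

# Sample 200 points and evaluate f2(x, y) = sin(x) + cos(y) and its gradient.
pts = rng.uniform(0.0, 2.0 * np.pi, size=(200, 2))
f = np.sin(pts[:, 0]) + np.cos(pts[:, 1])
grad = np.stack([np.cos(pts[:, 0]), -np.sin(pts[:, 1])], axis=1)   # analytic ∇f

# k-NN graph (k = 4): directed edges from each point to its 4 nearest neighbors.
tree = cKDTree(pts)
_, nbrs = tree.query(pts, k=5)        # the first neighbor is the point itself
edges = [(i, j) for i in range(len(pts)) for j in nbrs[i, 1:]]

# Target on edge (i, j): ∇f(x_i) · e_ij with e_ij the unit edge direction.
targets, fingrad = [], []
for i, j in edges:
    d = pts[j] - pts[i]
    e_ij = d / np.linalg.norm(d)
    targets.append(grad[i] @ e_ij)
    fingrad.append((f[j] - f[i]) / np.linalg.norm(d))   # FinGrad baseline

print("FinGrad MSE:", np.mean((np.array(fingrad) - np.array(targets)) ** 2))
```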
Table 3.1: Mean squared error ($10^{-2}$) for approximation of directional derivatives.

Functions | FinGrad | MLP | GN | One-w | SDL
$f_1(x, y) = 0.1x^2 + 0.5y^2$ | 6.42±0.47 | 2.12±0.32 | 1.05±0.42 | 1.41±0.44 | 0.97±0.39
$f_2(x, y) = \sin(x) + \cos(y)$ | 5.90±0.04 | 2.29±0.77 | 2.17±0.34 | 6.73±1.17 | 1.26±0.05
Approximation accuracy As shown in Table 3.1, the proposed spatial difference layer outperforms
others by a large margin. As expected, FinGrad provides the largest error since it only considers two
points without learnable parameters. We find that learnable parameters can significantly benefit the approximation of directional derivatives even when the input is the same (FinGrad vs. MLP). Note that utilizing neighboring information (GN, One-w, SDL) is generally helpful for learning spatial variations properly. However, simply training parameters in GN is not sufficient, and explicitly defining differences, which is important for understanding spatial variations, provides a more robust inductive bias. One important finding is that One-w is not as effective as GN and can even be worse than FinGrad, because of its limited degrees of freedom. As implied in the form $({}^{w}\nabla f)_{ij} = w_{ij} f_j - f_i$, only one $w_{ij}$ adjusts the relative difference between $f_i$ and $f_j$, and this is not enough to learn all possible linear combinations of $f_i$ and $f_j$. The unstable performance supports that the form of SDL is not ad hoc but more rigorously designed.
3.3.2 Graph Signal Prediction
We evaluate PA-DGN on synthetic data sampled from the simulation of specific convection-diffusion equations, to verify whether the proposed model can predict the next signals of the simulated dynamics from observations on discrete nodes only.
Simulated Data For the simulated dynamics, we discretize the following partial differential equation
similar to the one in [60] to simulate the corresponding linear variable-coefficient convection-diffusion
equation on graphs.
In a continuous space, we define the linear variable-coefficient convection-diffusion equation as:
$$\frac{\partial f}{\partial t} = a(x, y) f_x + b(x, y) f_y + c(x, y) \Delta f, \qquad f|_{t=0} = f_0(x, y), \qquad (3.4)$$
with $\Omega = [0, 2\pi] \times [0, 2\pi]$, $(t, x, y) \in [0, 0.2] \times \Omega$, $a(x, y) = 0.5(\cos(y) + x(2\pi - x)\sin(x)) + 0.6$, $b(x, y) = 2(\cos(y) + \sin(x)) + 0.8$, and $c(x, y) = 0.5\left(1 - \frac{\sqrt{(x-\pi)^2 + (y-\pi)^2}}{\sqrt{2}\,\pi}\right)$.
We follow the initialization setting in [60]:
$$f_0(x, y) = \sum_{|k|, |l| \leq N} \lambda_{k,l} \cos(kx + ly) + \gamma_{k,l} \sin(kx + ly), \qquad (3.5)$$
where $N = 9$, $\lambda_{k,l}, \gamma_{k,l} \sim \mathcal{N}\!\left(0, \frac{1}{50}\right)$, and $k$ and $l$ are chosen randomly.
We use spatial difference operators to approximate spatial derivatives:
$$\begin{aligned}
f_x(x_i, y_i) &= \frac{1}{2s}\big(f(x_i, y_i) - f(x_i - s, y_i)\big) - \frac{1}{2s}\big(f(x_i, y_i) - f(x_i + s, y_i)\big) \\
f_y(x_i, y_i) &= \frac{1}{2s}\big(f(x_i, y_i) - f(x_i, y_i - s)\big) - \frac{1}{2s}\big(f(x_i, y_i) - f(x_i, y_i + s)\big) \\
f_{xx}(x_i, y_i) &= \frac{1}{s^2}\big(f(x_i, y_i) - f(x_i - s, y_i)\big) + \frac{1}{s^2}\big(f(x_i, y_i) - f(x_i + s, y_i)\big) \\
f_{yy}(x_i, y_i) &= \frac{1}{s^2}\big(f(x_i, y_i) - f(x_i, y_i - s)\big) + \frac{1}{s^2}\big(f(x_i, y_i) - f(x_i, y_i + s)\big)
\end{aligned} \qquad (3.6)$$
where $s$ is the spatial grid size for discretization.
Then we rewrite (3.4) with difference operators defined on graphs:
$$\frac{\partial f}{\partial t} = a(i)(\nabla f)_{\hat{x}} + b(i)(\nabla f)_{\hat{y}} + c(i)\big((\Delta f)_{\hat{x}\hat{x}} + (\Delta f)_{\hat{y}\hat{y}}\big), \qquad f_i(0) = f_0(i), \qquad (3.7)$$
where
$$a(i)(x_j, y_j) = \begin{cases} \dfrac{a(x_i, y_i)}{2s} & \text{if } x_i = x_j + s,\ y_i = y_j \\[6pt] -\dfrac{a(x_i, y_i)}{2s} & \text{if } x_i = x_j - s,\ y_i = y_j \end{cases} \qquad (3.8)$$
$$b(i)(x_j, y_j) = \begin{cases} \dfrac{b(x_i, y_i)}{2s} & \text{if } x_i = x_j,\ y_i = y_j + s \\[6pt] -\dfrac{b(x_i, y_i)}{2s} & \text{if } x_i = x_j,\ y_i = y_j - s \end{cases} \qquad (3.9)$$
$$c(i)(x_j, y_j) = \frac{c}{s^2}. \qquad (3.10)$$
Then we replace the gradient with respect to time in (3.7) with a temporal discretization:
$$f(t + 1) = \Delta t\,\big(a(i)(\nabla f)_{\hat{x}} + b(i)(\nabla f)_{\hat{y}} + c(i)((\Delta f)_{\hat{x}\hat{x}} + (\Delta f)_{\hat{y}\hat{y}})\big) + f(t), \qquad f_i(0) = f_0(i), \qquad (3.11)$$
where $\Delta t$ is the time step of the temporal discretization.
Equation (3.11) is used for simulating the dynamics described by Equation (3.4). Then, we uniformly sample 250 points in the above 2D space and take their corresponding time series of $f$ as the dataset used in our synthetic experiments. We generate 1000 sessions on a 50 × 50 regular mesh with time step size $\Delta t = 0.01$. 700 sessions are used for training, 150 for validation, and 150 for testing.
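For concreteness, the following is a minimal sketch of such a forward-Euler rollout. It is an illustrative reimplementation, not the simulation code used here: it discretizes Eq. (3.4) directly with standard central differences, and the periodic boundary handling and random initial condition are assumptions of the sketch; the 50 × 50 mesh and $\Delta t = 0.01$ follow the text.

```python
import numpy as np

n, dt, T = 50, 0.01, 20                       # 50 x 50 mesh, time step 0.01, 20 rollout steps
s = 2 * np.pi / n                             # spatial grid size
x, y = np.meshgrid(np.arange(n) * s, np.arange(n) * s, indexing="ij")

# variable coefficients from Eq. (3.4)
a = 0.5 * (np.cos(y) + x * (2 * np.pi - x) * np.sin(x)) + 0.6
b = 2.0 * (np.cos(y) + np.sin(x)) + 0.8
c = 0.5 * (1 - np.sqrt((x - np.pi) ** 2 + (y - np.pi) ** 2) / (np.sqrt(2) * np.pi))

rng = np.random.default_rng(0)
f = rng.normal(0.0, 1.0, size=(n, n))         # placeholder initial condition (Eq. 3.5 in the text)

def step(f):
    # central differences with periodic wrap-around (an assumption of this sketch)
    fx = (np.roll(f, -1, 0) - np.roll(f, 1, 0)) / (2 * s)
    fy = (np.roll(f, -1, 1) - np.roll(f, 1, 1)) / (2 * s)
    fxx = (np.roll(f, -1, 0) - 2 * f + np.roll(f, 1, 0)) / s ** 2
    fyy = (np.roll(f, -1, 1) - 2 * f + np.roll(f, 1, 1)) / s ** 2
    return f + dt * (a * fx + b * fy + c * (fxx + fyy))   # forward Euler update

frames = [f]
for _ in range(T):
    frames.append(step(frames[-1]))
print(np.stack(frames).shape)                 # (T + 1, n, n)
```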
The task is to predict signal values of all points in the future M steps given observed values of the first
N steps. For our experiments, we choose N = 5 and M = 15. Since there is no a priori graph structure on
sampled points, we construct a graph with k-NN algorithm (k = 4) using the Euclidean distance. Figure 3.5
shows the dynamics and the graph structure of sampled points.
Model Settings To evaluate the effect of the proposed SDL on the above prediction task, we cascade
SDL and a linear regression model as our prediction model since the dynamics follows a linear partial
differential equation. We compare its performance with four baselines: (1) vector auto-regressor (VAR);
(2) multilayer perceptron (MLP); (3) StandardOP: the standard approximation of differential operators in
Section 3.2.1 followed by a linear regressor; (4) MeshOP: similar to StandardOP, but uses the discretization on a triangulated mesh in Section 3.2.2 for the differential operators.
Table 3.2: Mean absolute error (×10⁻²) for graph signal prediction.

VAR           MLP           StandardOP    MeshOP        SDL
16.84±0.41    15.75±0.53    11.99±0.29    12.82±0.06    10.87±0.98
Figure 3.5: Synthetic dynamics at t = 1, 5, 10, 15, 20 and the graph structure of sampled points.
Prediction Performance Table 3.2 shows the prediction performance of different models measured with mean absolute error. The prediction model with our proposed spatial difference layer outperforms the other baselines. All models incorporating any form of spatial difference operators (StandardOP, MeshOP and SDL) outperform those without spatial difference operators (VAR and MLP), showing that introducing spatial difference information inspired by the intrinsic dynamics helps prediction. However, when the points with observable signals are sparse in the space, spatial difference operators derived from fixed rules can be inaccurate and sub-optimal for prediction, since the locally linear assumption on which they are based no longer holds. Our proposed SDL, in contrast, bridges the gap between approximated and accurate difference operators by introducing learnable coefficients that utilize neighboring information, and thus improves prediction performance.
3.4 Prediction: Graph Signals on Land-based Weather Sensors
We evaluate the proposed model on the task of predicting climate observations (Temperature) from the
land-based weather stations located in the United States.
3.4.1 Experimental Set-up
Data and task We sample the weather stations located in the United States from the Online Climate
Data Directory of the National Oceanic and Atmospheric Administration (NOAA) and choose the stations
which have actively measured meteorological observations during 2015. We choose two geographically
close but meteorologically diverse groups of stations: the Western and Southeastern states. We use the k-Nearest Neighbor (k-NN) algorithm (k = 4) to generate graph structures, and the final adjacency matrix is $A = (A_k + A_k^\top)/2$ to make it symmetric, where $A_k$ is the adjacency matrix output by the k-NN algorithm. Figure 3.6 shows the distribution of the land-based weather stations and their connectivity. Since the stations are not synchronized and have different timestamps for their observations, we aggregate the time series hourly. The 1-year sequential data are split into a train set (8 months), a validation set (2 months), and a test set (2 months).

Figure 3.6: Weather stations in the (left) western and (right) southeastern United States and the k-NN graphs.
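A minimal sketch of this graph construction is shown below; the station coordinates are placeholders, and the use of scikit-learn's kneighbors_graph is an implementation choice of the example rather than necessarily the one used in this work.

```python
import numpy as np
from sklearn.neighbors import kneighbors_graph

# Hypothetical station coordinates (longitude, latitude); in the experiments these
# come from the NOAA station metadata.
coords = np.random.default_rng(0).uniform(size=(30, 2))

A_k = kneighbors_graph(coords, n_neighbors=4, mode="connectivity").toarray()
A = (A_k + A_k.T) / 2          # symmetrize: A = (A_k + A_k^T) / 2 as in the text
```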
Our main task is to predict the next graph signals based on the current and past graph signals. All
methods we evaluate are trained through the objective (Eq 3.3) with the Adam optimizer and we use
scheduled sampling [8] for the models with recurrent modules. We evaluate PA-DGN and other baselines
on two prediction tasks: (1) 1-step and (2) multistep-ahead prediction. Furthermore, we present an ablation study showing how important the spatial derivatives from our proposed SDL are as signals for predicting the graph dynamics.
Model Settings Unless mentioned otherwise, all models use a hidden dimension of size 64.
• VAR: A vector autoregression model with 2 lags. Input is the concatenated features of previous 2
frames. The weights are shared among all nodes in the graph.
• MLP: A multilayer perceptron model with 2 hidden layers. Input is the concatenated features of
previous 2 frames. The weights are shared among all nodes in the graph.
• GRU: A Gated Recurrent Unit network with 2 hidden layers. Input is the concatenated features of
previous 2 frames. The weights are shared among all nodes in the graph.
• RGN: A recurrent graph neural network model with 2 GN blocks. Each GN block has an edge update
block and a node update block, both of which use a 2-layer GRU cell as the update function. We set
its hidden dimension to 73 so that it has the same number of learnable parameters as our proposed
model PA-DGN.
• RGN(StandardOP): Similar to RGN, but use the output of difference operators in Section 3.2.1 as
extra input features. We set its hidden dimension to 73.
• RGN(MeshOP): Similar to RGN(StandardOP), but the extra input features are calculated using the operators in Section 3.2.2. We set its hidden dimension to 73.
• PA-DGN: Our proposed model. The spatial derivative layer uses a message passing neural network
(MPNN) with 2 GN blocks using 2-layer MLPs as update functions. The forward network part uses
a recurrent graph neural network with 2 recurrent GN blocks using 2-layer GRU cells as update
functions.
The numbers of learnable parameters of all models are listed as follows:
Table 3.3: Numbers of learnable parameters.
Model VAR MLP GRU RGN RGN(StandardOP) RGN(MeshOP) PA-DGN
# Params 3 4,417 37,889 345,876 341,057 342,152 340,001
Training Settings
• The number of evaluation runs: We performed every experiment in this paper 3 times to report the mean and standard deviations.
• Length of prediction: For experiments on NOAA datasets, all models take first 12 frames as input
and predict the following 12 frames.
• Training hyper-parameters: We use the Adam optimizer with learning rate 1e-3, batch size 8, and weight decay 5e-4. All experiments are trained for a maximum of 2000 epochs with early stopping. All experiments are trained using inverse sigmoid scheduled sampling with the coefficient k = 107 (a sketch of this schedule is given after this list).
• Environments: All experiments are implemented with Python 3.6 and PyTorch 1.1.0, and are run with NVIDIA GTX 1080 Ti GPUs.
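For reference, the inverse sigmoid decay commonly used for scheduled sampling [8] sets the probability of feeding the ground truth at training step i to k/(k + exp(i/k)); the sketch below assumes this standard form with the coefficient stated above.

```python
import math

def teacher_forcing_prob(step, k=107.0):
    """Inverse sigmoid decay from the scheduled sampling paper [8]:
    eps_i = k / (k + exp(i / k)); k is the coefficient from the training settings."""
    return k / (k + math.exp(step / k))

# The probability decays from ~1 toward 0 as training proceeds.
print([round(teacher_forcing_prob(s), 3) for s in (0, 500, 1000, 2000)])
```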
3.4.2 Graph Signal Predictions
We compare against the widely used baselines (VAR, MLP, and GRU) for 1-step and multistep prediction.
Then, we use RGN [82] to examine how much the graph structure is beneficial. Finally, we evaluate PA-DGN
to verify if the proposed architecture (Eq 3.1) is able to reduce prediction loss. Experiment results for the
prediction task are summarized in Table 3.4.
Overall, RGN and PA-DGN are better than the other baselines, which implies that the graph structure provides a useful inductive bias for the task. This is intuitive, as the meteorological observations change continuously over space and time, and thus the observations at the i-th station are strongly related to those of its neighboring stations.

PA-DGN outperforms RGN, and the gap comes from the fact that the spatial derivatives (Eq 3.1) we feed to PA-DGN are beneficial. This finding is expected because the meteorological signal at a certain point is a function not only of its previous value but also of the relative differences between neighboring signals and itself. Knowing the relative differences among local observations is particularly essential for understanding physics-related dynamics. For example, the diffusion equation, which describes how physical quantities (e.g., heat) are transported through space over time, is also a function of relative differences of the quantities ($\frac{df}{dt} = D\Delta f$) rather than of the neighboring values themselves. In other words, spatial differences are physics-aware features and it is desirable to leverage them as input to learn dynamics related to physical phenomena.
Table 3.4: Graph signal prediction results (MAE) on multistep predictions. In each row, we report the average with standard deviations for all baselines and PA-DGN. One step is a 1-hour time interval.
Region Method 1-step 6-step 12-step
West
VAR 0.1241 ± 0.0234 0.4295 ± 0.1004 0.4820 ± 0.1298
MLP 0.1040 ± 0.0003 0.3742 ± 0.0238 0.4998 ± 0.0637
GRU 0.0913 ± 0.0047 0.1871 ± 0.0102 0.2707 ± 0.0006
RGN 0.0871 ± 0.0033 0.1708 ± 0.0024 0.2666 ± 0.0252
RGN(StandardOP) 0.0860 ± 0.0018 0.1674 ± 0.0019 0.2504 ± 0.0107
RGN(MeshOP) 0.0840 ± 0.0015 0.2119 ± 0.0018 0.4305 ± 0.0177
PA-DGN 0.0840 ± 0.0004 0.1614 ± 0.0042 0.2439 ± 0.0163
SouthEast
VAR 0.0889 ± 0.0025 0.2250 ± 0.0013 0.3062 ± 0.0032
MLP 0.0722 ± 0.0012 0.1797 ± 0.0086 0.2514 ± 0.0154
GRU 0.0751 ± 0.0037 0.1724 ± 0.0130 0.2446 ± 0.0241
RGN 0.0790 ± 0.0113 0.1815 ± 0.0239 0.2548 ± 0.0210
RGN(StandardOP) 0.0942 ± 0.0121 0.2135 ± 0.0187 0.2902 ± 0.0348
RGN(MeshOP) 0.0905 ± 0.0012 0.2052 ± 0.0012 0.2602 ± 0.0062
PA-DGN 0.0721 ± 0.0002 0.1664 ± 0.0011 0.2408 ± 0.0056
3.4.3 Contribution of Spatial Derivatives
We further investigate whether the modulated spatial derivatives (Eq 3.1) are advantageous compared to the spatial derivatives defined on Riemannian manifolds. First, RGN without any spatial derivatives is assessed on the prediction tasks for the Western and Southeastern states graph signals. Note that this model does not use any extra features other than the graph signal f(t). Second, we add (1) StandardOP, the discrete spatial differences (gradient and Laplacian) in Section 3.2.1, and (2) MeshOP, the triangular mesh approximation of differential operators in Section 3.2.2, separately as additional signals to RGN. Finally, we incorporate our proposed Spatial Difference Layer into RGN.

Table 3.4 shows the contribution of each component. As expected, PA-DGN provides much larger drops in MAE (3.56%, 5.50%, 8.51% and 8.73%, 8.32%, 5.49% on the two datasets, respectively) compared to RGN without derivatives, and the results demonstrate that the derivatives, namely the relative differences from neighboring signals, are effectively useful. However, neither RGN with StandardOP nor RGN with MeshOP can consistently outperform RGN. We also find that PA-DGN consistently improves the prediction error compared to the fixed derivatives. This finding is evidence that the parameters modulating spatial derivatives in our proposed Spatial Difference Layer are properly inferred to optimize the networks and to improve prediction performance.
3.5 Evaluation on NEMO sea surface temperature (SST) dataset
We test our proposed method and baselines on the NEMO sea surface temperature (SST) dataset∗. We first download the data in the area between 50°N-65°N and 75°W-10°W from 2016-01-01 to 2017-12-31, then crop the [0, 550] × [100, 650] square from the area and sample 250 points from the square as our dataset. We divide the data into 24 sequences, each lasting 30 days, and truncate the tail. All models use the first 5 days of SST as input and predict the SST in the following 15 and 25 days. We use the data in 2016 for training all models and the rest for testing.
For StandardOP, MeshOP and SDL, we test both options using linear regression and using RGN for
the prediction part and report the best result. The results in Table 3.5 show that all methods incorporating
spatial differences gain improvement on prediction and that our proposed learnable SDL outperforms all
other baselines.
∗Available at http://marine.copernicus.eu/services-portfolio/access-to-products/?option=com_csw&view=
details&product_id=GLOBAL_ANALYSIS_FORECAST_PHY_001_024.
Table 3.5: Mean absolute error (×10⁻²) for SST graph signal prediction.

          VAR       MLP       GRU       RGN       StandardOP    MeshOP    SDL
15-step   15.123    15.058    15.101    15.172    14.756        14.607    14.382
25-step   19.533    19.473    19.522    19.705    18.983        18.977    18.434
3.6 Ablation Study
3.6.1 Effect of different graph structures
In this section, we evaluate the effect of 2 different graph structures on baselines and our models: (1)
k-NN: a graph constructed with k-NN algorithm (k = 4); (2) TriMesh: a graph generated with Delaunay
Triangulation. All graphs use the Euclidean distance.
Table 3.6: Mean absolute error (×10⁻²) for graph signal prediction on the synthetic dataset.

         VAR      MLP      StandardOP           MeshOP               SDL
                           k-NN      TriMesh    k-NN      TriMesh    k-NN      TriMesh
         17.30    16.27    12.00     12.29      12.87     12.82      11.04     12.40
Table 3.6 and Table 3.7 show the effect of different graph structures on the synthetic dataset used in Section 3.3.2 and the real-world dataset in Section 3.4.2, respectively. We find that the effect of graph structures is not homogeneous across models. For RGN and PA-DGN, the k-NN graph is more beneficial to prediction performance than the TriMesh graph, because these two models rely more on neighboring information and a k-NN graph captures it better than a Delaunay triangulation graph. However, switching from the TriMesh graph to the k-NN graph is harmful to the prediction accuracy of RGN(MeshOP), since Delaunay triangulation is a well-defined method for generating a triangulated mesh, in contrast to k-NN graphs. Given the varying effects of graph structures on different models, our proposed PA-DGN with a k-NN graph always outperforms the other baselines under any graph structure.
Table 3.7: Graph signal prediction results (MAE) on multistep predictions. In each row, we report the average with standard deviations for all baselines and PA-DGN. One step is a 1-hour time interval.
Region Method Graph 1-step 6-step 12-step
West
VAR - 0.1241 ± 0.0234 0.4295 ± 0.1004 0.4820 ± 0.1298
MLP - 0.1040 ± 0.0003 0.3742 ± 0.0238 0.4998 ± 0.0637
GRU - 0.0913 ± 0.0047 0.1871 ± 0.0102 0.2707 ± 0.0006
RGN k-NN 0.0871 ± 0.0033 0.1708 ± 0.0024 0.2666 ± 0.0252
TriMesh 0.0897 ± 0.0030 0.1723 ± 0.0116 0.2800 ± 0.0414
RGN
(StandardOP)
k-NN 0.0860 ± 0.0018 0.1674 ± 0.0019 0.2504 ± 0.0107
TriMesh 0.0842 ± 0.0011 0.1715 ± 0.0027 0.2517 ± 0.0369
RGN
(MeshOP)
k-NN 0.0840 ± 0.0015 0.2119 ± 0.0018 0.4305 ± 0.0177
TriMesh 0.0846 ± 0.0017 0.2090 ± 0.0077 0.4051 ± 0.0457
PA-DGN k-NN 0.0840 ± 0.0004 0.1614 ± 0.0042 0.2439 ± 0.0163
TriMesh 0.0849 ± 0.0012 0.1610 ± 0.0029 0.2473 ± 0.0162
SouthEast
VAR - 0.0889 ± 0.0025 0.2250 ± 0.0013 0.3062 ± 0.0032
MLP - 0.0722 ± 0.0012 0.1797 ± 0.0086 0.2514 ± 0.0154
GRU - 0.0751 ± 0.0037 0.1724 ± 0.0130 0.2446 ± 0.0241
RGN k-NN 0.0790 ± 0.0113 0.1815 ± 0.0239 0.2548 ± 0.0210
TriMesh 0.0932 ± 0.0105 0.2076 ± 0.0200 0.2854 ± 0.0211
RGN
(StandardOP)
k-NN 0.0942 ± 0.0121 0.2135 ± 0.0187 0.2902 ± 0.0348
TriMesh 0.0868 ± 0.0132 0.1885 ± 0.0305 0.2568 ± 0.0328
RGN
(MeshOP)
k-NN 0.0913 ± 0.0016 0.2069 ± 0.0031 0.2649 ± 0.0092
TriMesh 0.0877 ± 0.0020 0.2043 ± 0.0026 0.2579 ± 0.0057
PA-DGN k-NN 0.0721 ± 0.0002 0.1664 ± 0.0011 0.2408 ± 0.0056
TriMesh 0.0876 ± 0.0096 0.2002 ± 0.0163 0.2623 ± 0.0180
3.6.2 Evaluation on datasets with different sparsity
We changed the number of nodes to control the sparsity of data. As shown in Table 3.8, our proposed model
outperforms others under various settings of sparsity on the synthetic experiment in Section 3.3.2.
Furthermore, we sampled 400 points and trained SDL as described in Section 3.3.1, and resampled fewer
points (350,300,250,200) to evaluate if SDL generalizes less sparse setting. As Table 3.9 shows, MSE increases
when fewer sample points are used. However, SDL is able to provide much more accurate gradients even
if it is trained under a new graph with different properties. Thus, the results support that SDL is able to
generalize the c setting.
Table 3.8: Mean absolute error (×10⁻²) for graph signal prediction with different sparsity.
#Nodes VAR MLP StandardOP MeshOP SDL
250 0.1730 0.1627 0.1200 0.1287 0.1104
150 0.1868 0.1729 0.1495 0.1576 0.1482
100 0.1723 0.1589 0.1629 0.1696 0.1465
Table 3.9: Mean squared error (×10⁻²) for approximations of directional derivatives of the function f2(x, y) = sin(x) + cos(y) with different sparsity.
Method 350 Nodes 300 Nodes 250 Nodes 200 Nodes
FinGrad 2.88 ± 0.11 3.42 ± 0.14 3.96 ± 0.17 4.99 ± 0.31
SDL 1.03 ± 0.09 1.14 ± 0.12 1.40 ± 0.10 1.76 ± 0.10
3.7 The distribution of prediction error across nodes
Figure 3.7: MAE across the nodes (panels show the MAE distribution for 1-step, 6-step, and 12-step prediction).
Figure 3.7 provides the distribution of MAEs across the nodes of PA-DGN applied to the graph signal
prediction task of the west coast region of the real-world dataset in Section 3.4.2. As shown in the figure,
nodes with the highest prediction error for short-term prediction are gathered in the inner part where
the observable nodes are sparse, while for long-term prediction nodes in the area with a limited number
of observable points no longer have the largest MAE. This implies that PA-DGN can utilize neighboring
information efficiently even under the limitation of sparsely observable points.
3.8 Conclusion
In this paper, we introduce a novel architecture (PA-DGN) that approximates spatial derivatives and uses them to represent PDEs, which play a prominent role in physics-aware modeling. PA-DGN effectively learns the modulated derivatives for prediction, and the derivatives can be used to discover hidden physics describing interactions between temporal and spatial derivatives.
Chapter 4
Physics-Informed Long-Sequence Forecasting From Multi-Resolution
Spatiotemporal Data
4.1 Introduction
Forecasting from spatiotemporal data has wide applications in domains such as transportation and energy.
The input and output of such forecasting tasks are both sequences of graph signal frames, where each
frame is a graph with multivariate features defined on nodes and edges. In these applications, long-run
forecasting is usually required for planning and policy making. As illustrated by [113], forecasting long
sequences into the long-run future has higher requirements on the long-range alignment ability of models
compared to short-run forecasting tasks.
One key aspect for forecasting long sequences is effectively modeling the complex patterns and dynamics
of real-world spatiotemporal data aggregated in various spatial and temporal resolutions. Figure 4.1 shows
an example from the taxi demand dataset. Data aggregated in different spatial (first 4 rows vs. the last row)
and temporal (left/middle/right columns) resolutions demonstrate correlated but heterogeneous patterns,
implying the necessity of refined modeling of multi-resolution data. In practical cases where data of high
resolutions usually suffer from high missing rates and low signal-to-noise ratios due to the high cost of
Figure 4.1: Example of multi-resolution spatiotemporal data: taxi pickup rates. Each of the first 4 rows displays the change of taxi pickup rates (total pickups within 30 minutes) over a one-week period for one of the 4 pink taxi zones (bounded in black on the map). The last row shows the aggregated (summed) taxi pickup rates in the borough composed of the 4 pink taxi zones, and is thus of a coarser spatial resolution. Each column stands for one temporal aggregation resolution. Data in coarser spatial/temporal resolutions is aggregated from data in finer resolutions when the latter is fully observed. In other scenarios, it can also be collected from different data sources and thus have heterogeneous quality.
collection, correctly capturing the interaction among data in various resolutions is even more critical for
forecasting.
While existing works have achieved great success in short-run predictions of uni-resolution spatiotemporal data within several hours into the future [109, 22, 105, 102, 112, 14, 116, 11], research on long-run forecasting of spatiotemporal data remains largely underdeveloped. [113] proposes an efficient Transformer-based [96] architecture for long-sequence time series forecasting, but it still lacks modeling of spatial relations and ignores inter-resolution dynamics.
To bridge the gap between multi-resolution spatiotemporal data and long-sequence forecasting, the
proposed model must address challenges in two aspects: (1) Inter-resolution modeling. The model must be
able to capture interactions of data among resolutions to fully utilize them for forecasting. (2) Intra-resolution
modeling. The model should effectively extract information representing the dynamics from either short or
long sequences of each resolution.
In this paper, we propose Spatiotemporal Koopman Multi-Resolution Network (ST-KMRN) to address
both challenges. For better inter-resolution modeling, we leverage the self-attention mechanism to fuse and
communicate representations of input data in all temporal resolutions. The model is then jointly trained
with forecasting losses over all spatial and temporal resolutions. We further construct downsampling
and upsampling modules among the forecasting outcomes of different temporal resolutions to enhance inter-resolution connections. To improve intra-resolution modeling, we combine the Koopman theory-based
modeling of dynamic systems with deep learning-based prediction modules for more principled modeling
of dynamics. The inferred Koopman matrix offers interpretations of heterogeneous dynamics in different
resolutions at the same time.
Our contributions are: (1) ST-KMRN captures inter-resolution dynamics of spatiotemporal data by connecting representations at different resolution levels via self-attention, providing an effective way of leveraging multi-resolution data. (2) ST-KMRN further improves inter-resolution modeling with
upsampling and downsampling modules among predictions of various resolutions. (3) ST-KMRN improves
the modeling of intra-resolution dynamics with the combination of Koopman theory-based modeling
and deep learning-based forecasting models, which also provides interpretable information of different
dynamics in multi-resolution data. (4) ST-KMRN achieves state-of-the-art performance on the long-sequence
forecasting tasks from real-world spatiotemporal datasets.
4.2 Problem Formulation
In this work, we focus on the task of forecasting spatiotemporal multivariate time series with multiple resolutions in both space and time. We denote the set of regions by $S = \{s_1, s_2, \dots, s_N\}$, the set of $P$ historical time steps by $T_H = \{t_1, t_2, \dots, t_P\}$, the set of $Q$ future time steps by $T_F = \{t_{P+1}, \dots, t_{P+Q}\}$, and the multivariate variable of dimension $D$ at region $s$ and time step $t$ by $x_{t,s} \in \mathbb{R}^D$.
Definition 1 (Resolution). A resolution $R(X)$ is a partition of the set $X$. In addition, for a temporal resolution $R(T)$ ($T = T_H$ or $T = T_F$), we require $|T| = |R(T)| \times r$ for some $r \in \mathbb{N}$, and $R(T) = \{\{t_1, \dots, t_r\}, \dots, \{t_{(|R(T)|-1)\times r + 1}, \dots, t_{|R(T)|\times r}\}\}$. We name $r$ the scale of the temporal resolution $R(T)$.

Definition 2 (Aggregated Variable). The aggregated multivariate variable over a set of regions $S'$ and a set of time steps $T'$ is defined as $x^{\mathrm{agg}}_{T', S'} = \mathrm{agg}(\{x_{t,s} \mid t \in T', s \in S'\})$, where $\mathrm{agg}$ is some aggregation function $\mathrm{agg}: \mathbb{R}^{N \times D} \to \mathbb{R}^D$, such as summation or average. We omit $\mathrm{agg}$ in the following notation since the aggregation function is usually the same for a given dataset and can be inferred from the context.

Definition 3 (Observation at Given Spatial and Temporal Resolutions). The observation at a given spatial resolution $R(S)$ and a given temporal resolution $R(T)$ is defined as a set of aggregated multivariate variables: $O_{R(T), R(S)} = \{x_{r_T, r_S} \mid r_T \in R(T), r_S \in R(S)\} \in \mathbb{R}^{|R(T)| \times |R(S)| \times D}$.
We then formulate the input and output of the forecasting problem with multi-resolution spatiotemporal input data as follows:

Input (1) The set of regions $S$, the set of historical time steps $T_H$, and the set of future time steps $T_F$. (2) The list of available temporal resolutions $\mathrm{TR} = \{R^T_1(T_H), \dots, R^T_{R_T}(T_H)\}$, and the list of available spatial resolutions $\mathrm{SR} = \{R^S_1(S), \dots, R^S_{R_S}(S)\}$. $R^T_1(T_H)$ and $R^S_1(S)$ are the highest temporal and spatial resolutions, respectively; in other words, $|r_t| = 1\ \forall r_t \in R^T_1(T_H)$ and $|r_s| = 1\ \forall r_s \in R^S_1(S)$. (3) The set of historical spatio-temporal resolution pairs $\mathrm{HR} = \mathrm{TR} \times \mathrm{SR}$. (4) The set of historical observations in multiple spatial and temporal resolutions $O = \{O_{TR, SR} \mid (TR, SR) \in \mathrm{HR}\}$.

Output The set of forecast values at the target spatial and temporal resolution $\hat{Y} = O_{R^T_{\mathrm{out}}(T_F), R^S_{\mathrm{out}}(S)}$.
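To make the notation concrete, the sketch below derives a coarser temporal-resolution observation from the finest one by aggregation; summation is assumed as the aggregation function, and the sizes are only loosely modeled on the YellowCab setting described later.

```python
import numpy as np

P, N, D = 1440, 67, 2            # historical steps, regions, feature dimension
scale = 12                       # e.g., 30-minute steps aggregated into 6-hour steps
x = np.random.default_rng(0).random((P, N, D))   # finest-resolution observation

# Aggregate every `scale` consecutive time steps (summation) to obtain the coarser resolution.
coarse = x.reshape(P // scale, scale, N, D).sum(axis=1)
print(coarse.shape)              # (|R(T_H)|, |S|, D) = (120, 67, 2)
```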
Figure 4.2: ST-KMRN's architecture. (a) Overview of ST-KMRN's architecture ($R_T = 3$). (b) Neural network based module.
4.3 Spatiotemporal Koopman Multi-Resolution Network
In this section, we introduce our proposed model, the Spatiotemporal Koopman Multi-Resolution Network (ST-KMRN). Figure 4.2a provides an overview of its architecture.

Input historical observations $O$ in multiple spatial and temporal resolutions are first grouped by temporal resolution into $\{O_{TR_1}, \dots, O_{TR_{R_T}}\}$, where $O_{TR_i} = \mathrm{concat}(\{O_{TR_i, SR_1}, \dots, O_{TR_i, SR_{R_S}}\}) \in \mathbb{R}^{|TR_i| \times N_S \times D}$ is the multivariate time series concatenated along the spatial dimension, with $N_S = \sum_{k=1}^{R_S} |SR_k|$. For $O_{TR_i}$, we construct a spatial hierarchical graph $G_i = (V_i, E_i)$, where $V_i = \bigcup_{k=1}^{R_S} SR_k$ contains region sets at all spatial resolution levels and $E_i$ is composed of 2 types of edges: (1) intra-spatial-resolution edges constructed with prior knowledge such as spatial information or connection strengths; (2) inter-spatial-resolution edges between all pairs of region sets $(S_p, S_q)$ iff $S_p \subset S_q$.

Each pair $(O_{TR_i}, G_i)$ is fed into the $i$-th spatiotemporal encoder to generate a temporal-resolution-specific embedding $E_i \in \mathbb{R}^{N_S \times H}$, where $H$ is the embedding dimension. The embeddings of all temporal resolutions are then propagated and updated with the self-attention module, and the updated embedding $E'_i$ serves as the input of the $i$-th decoder to produce the deep learning based forecasting result $\hat{Y}^{DL}_{TR_i}$. Meanwhile, a Koopman theory based forecasting module takes $O_{TR_i}$ as its input and produces its forecast $\hat{Y}^{K}_{TR_i}$ in parallel. $\hat{Y}^{DL}_{TR_i}$ and $\hat{Y}^{K}_{TR_i}$ are then fused with a gating mechanism to obtain the first-stage forecasting result $\hat{Y}^{1}_{TR_i}$. We further design upsampling and downsampling modules for the second-stage forecasting: $\hat{Y}^{1}_{TR_i}$, $\hat{Y}^{ds}_{TR_i} = \mathrm{DownSampling}_{i-1}(\hat{Y}^{1}_{TR_{i-1}})$, and $\hat{Y}^{ups}_{TR_i} = \mathrm{UpSampling}_{i+1}(\hat{Y}^{1}_{TR_{i+1}})$ are combined with an attention mechanism for the final forecast $\hat{Y}_{TR_i}$ (for $i = 1$ / $i = R_T$, the $\hat{Y}^{ds}_{TR_i}$ / $\hat{Y}^{ups}_{TR_i}$ term is omitted, respectively).
4.3.1 Physics-Informed Modeling of Intra-Resolution Dynamics
In this part, we introduce our proposed physics-informed modeling of intra-resolution dynamics, where
a neural network based module and a Koopman theory based module are combined for encoding and
forecasting in each temporal resolution.
4.3.1.1 Neural Network Based Module (ST-Encoder and ST-Decoder)
The architecture of the neural network based module is a modification of Graph WaveNet [102]. Figure 4.2b shows the architecture of the encoder part, where multiple temporal convolution layers (TCN) [73] and graph convolution layers (GCN) [41] are stacked alternately to encode spatial dependencies at various temporal scales. The output of the encoder is an embedding vector $E_i$ summarizing the input sequence $O_{TR_i}$. $E_i$ is then propagated with the embeddings from other resolutions to form $E'_i$ via self-attention [96]. The decoder is composed of two 1D convolutional layers applied along the spatial dimension, which map $E'_i$ to the forecasting result $\hat{Y}^{DL}_{TR_i}$.
4.3.1.2 Koopman Module
Koopman Theory Consider a non-linear dynamic system with its state vector at time $t$ denoted as $s_t \in \mathbb{R}^m$. The system can be described as $s_{t+1} = F(s_t)$. As defined in [43], the Koopman operator $\mathcal{K}_F$ is a linear transformation defined on a function space $\mathcal{F}$ by $\mathcal{K}_F g = g \circ F$ for every $g: \mathbb{R}^m \to \mathbb{R}$ that belongs to the infinite-dimensional Hilbert space $\mathcal{F}$. With this definition, we have $\mathcal{K}_F g(s_t) = g \circ F(s_t) = g(s_{t+1})$.

The Koopman theory [43] guarantees the existence of $\mathcal{K}_F$, but in practice we often assume the existence of an invariant finite-dimensional subspace $\mathcal{G}$ of $\mathcal{F}$ spanned by $k$ bases $\{g_1, g_2, \dots, g_k\}$. Define $\mathbf{g}_t = [g_1(s_t), g_2(s_t), \dots, g_k(s_t)]^T$ and $\mathbf{g}_{t+1} = [g_1(s_{t+1}), g_2(s_{t+1}), \dots, g_k(s_{t+1})]^T$; under this assumption we have $\mathbf{g}_t, \mathbf{g}_{t+1} \in \mathcal{G}$ and there exists a Koopman matrix $K \in \mathbb{R}^{k \times k}$ such that $\mathbf{g}_{t+1} = K \mathbf{g}_t$.
Koopman Theory Based Modeling of Temporal Dynamics Compared to directly modeling the state propagation function $F$ with neural networks, the Koopman theory provides prior knowledge related to the system. In addition, the resulting Koopman matrix is more interpretable, as the propagation of hidden states can be represented as a linear mapping.

The key problem is to find the pair of mappings between the state space $\mathbb{R}^m$ and the invariant subspace $\mathcal{G} \subset \mathbb{R}^k$: $g: \mathbb{R}^m \to \mathbb{R}^k$ and $g^{-1}: \mathbb{R}^k \to \mathbb{R}^m$. Here we model them with 2 multi-layer perceptrons (MLPs). $g$ maps each frame of the input $O_{TR_i}$ to the subspace and generates $H_{TR_i} \in \mathbb{R}^{|TR_i| \times k}$. Then, the Koopman matrix $K_{TR_i}$ is estimated as
\[
\arg\min_{K} \left\| K \left(H_{TR_i}[0:-1]\right)^T - \left(H_{TR_i}[1:]\right)^T \right\|_2^2. \tag{4.1}
\]
With $H_{TR_i}[-1]$ as the starting state $q_0$ for forecasting, the module forecasts $T_i$ steps into the future in a recurrent way: $\hat{Y}^{K}_{TR_i} = [g^{-1}(K_{TR_i} q_0), g^{-1}(K_{TR_i}^2 q_0), \dots, g^{-1}(K_{TR_i}^{T_i} q_0)]$.
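A minimal PyTorch sketch of this module is given below. It is an illustrative reimplementation under the formulation above, not the released code: g and g⁻¹ are small MLPs, K is obtained in closed form via the pseudo-inverse solution of Eq. (4.1), and the forecast is rolled out recurrently; the layer sizes are assumptions.

```python
import torch
import torch.nn as nn

class KoopmanForecaster(nn.Module):
    def __init__(self, state_dim, hidden_dim=64, koopman_dim=32):
        super().__init__()
        # g maps states to the invariant subspace, g_inv maps back to the state space.
        self.g = nn.Sequential(nn.Linear(state_dim, hidden_dim), nn.ReLU(),
                               nn.Linear(hidden_dim, koopman_dim))
        self.g_inv = nn.Sequential(nn.Linear(koopman_dim, hidden_dim), nn.ReLU(),
                                   nn.Linear(hidden_dim, state_dim))

    def forward(self, obs, horizon):
        # obs: (T, state_dim) frames of one temporal resolution.
        H = self.g(obs)                                        # (T, k)
        # K such that K H[:-1]^T ~ H[1:]^T  (Eq. 4.1), via the pseudo-inverse.
        K = H[1:].T @ torch.linalg.pinv(H[:-1].T)              # (k, k)
        q = H[-1]                                              # starting state q_0
        preds = []
        for _ in range(horizon):                               # recurrent rollout
            q = K @ q
            preds.append(self.g_inv(q))
        return torch.stack(preds)                              # (horizon, state_dim)

model = KoopmanForecaster(state_dim=8)
print(model(torch.randn(20, 8), horizon=15).shape)             # torch.Size([15, 8])
```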
4.3.1.3 Gated Fusion of ST-Encoder and Koopman Module
While the Koopman theory based module is more principled and can incorporate prior knowledge of the system, the neural network module is far more expressive and flexible. To combine the advantages of both modules, we design a fusion module that jointly leverages their forecasts. More specifically, we use a 2-layer 1D convolutional neural network (CNN) with the sigmoid activation function to generate a gating vector. The vector controls the ratio of each module's prediction in the combined forecasting result as follows:
\[
\eta_{TR_i} = f_{\mathrm{Gate}}(\mathrm{concat}(\hat{Y}^{DL}_{TR_i}, \hat{Y}^{K}_{TR_i})) \tag{4.2}
\]
\[
\hat{Y}^{1}_{TR_i} = (1 - \eta_{TR_i}) \otimes \hat{Y}^{DL}_{TR_i} + \eta_{TR_i} \otimes \hat{Y}^{K}_{TR_i}, \tag{4.3}
\]
where $\otimes$ is the element-wise product.
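A compact sketch of the gating in Eqs. (4.2)-(4.3); the kernel size, hidden width, and channel layout are assumptions of this example.

```python
import torch
import torch.nn as nn

class GatedFusion(nn.Module):
    """Fuse the neural-network forecast and the Koopman forecast per Eqs. (4.2)-(4.3)."""
    def __init__(self, dim, hidden=64):
        super().__init__()
        self.f_gate = nn.Sequential(            # 2-layer 1D CNN over the time axis
            nn.Conv1d(2 * dim, hidden, kernel_size=3, padding=1), nn.ReLU(),
            nn.Conv1d(hidden, dim, kernel_size=3, padding=1), nn.Sigmoid())

    def forward(self, y_dl, y_k):
        # y_dl, y_k: (batch, dim, T) forecasts from the two modules.
        eta = self.f_gate(torch.cat([y_dl, y_k], dim=1))   # gating ratio in [0, 1]
        return (1 - eta) * y_dl + eta * y_k                # Eq. (4.3)

fuse = GatedFusion(dim=2)
print(fuse(torch.randn(4, 2, 480), torch.randn(4, 2, 480)).shape)
```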
4.3.2 Inter-Resolution Dynamics Modeling
In this part, we introduce the self-attention module that fuses representations over all temporal resolutions
as well as the upsampling and downsampling modules for the second stage forecasting.
4.3.2.1 Self-Attention Module
The self-attention mechanism, proposed in [96], captures dependencies among input elements regardless of their distance in the input sequence. With a similar motivation, we adopt self-attention to model the interaction of representations from multiple temporal resolutions. Given the list of resolution-specific embeddings $[E_1, E_2, \dots, E_{R_T}]$, we concatenate all embeddings to construct the input representation $E \in \mathbb{R}^{N_S \times R_T \times H}$. We then update the embeddings of all regions in parallel via multi-head self-attention across resolutions.
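A minimal sketch of this step using PyTorch's nn.MultiheadAttention, treating the $R_T$ resolution-specific embeddings of each region as a length-$R_T$ sequence; the head count and sizes are assumptions.

```python
import torch
import torch.nn as nn

N_S, R_T, H = 78, 3, 64                      # regions across resolutions, temporal resolutions, embedding dim
E = torch.randn(N_S, R_T, H)                 # stacked resolution-specific embeddings [E_1, ..., E_{R_T}]

attn = nn.MultiheadAttention(embed_dim=H, num_heads=4, batch_first=True)
E_updated, _ = attn(E, E, E)                 # self-attention across resolutions, all regions in parallel
print(E_updated.shape)                       # torch.Size([78, 3, 64])
```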
4.3.2.2 Upsampling and Downsampling Modules
Upsampling Module We set up a learnable upsampling module $\mathrm{UpSampling}_i$ for each $\hat{Y}^{1}_{TR_i}$, $i > 1$. Denote by $r_i$ the scale of the temporal resolution $TR_i$. In our experimental settings we always have $r_i = k_i r_{i-1}$ for some $k_i \in \mathbb{N}$, $\forall i > 1$. We implement the $i$-th upsampling module as a 1D CNN with $D$ input channels and $k_i D$ output channels. Its output is then rearranged, in the same way as PixelShuffle [90], into the upsampled forecasting result $\hat{Y}^{ups}_{TR_{i-1}}$.
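A sketch of one upsampling module ($k_i = 12$, e.g., mapping 6-hour steps to 30-minute steps); the kernel size is an assumption, and the channel-to-time rearrangement mirrors PixelShuffle [90] in 1D.

```python
import torch
import torch.nn as nn

class UpSampling1D(nn.Module):
    """Map a coarse forecast (B, D, T) to a finer one (B, D, k * T)."""
    def __init__(self, d, k):
        super().__init__()
        self.k = k
        self.conv = nn.Conv1d(d, k * d, kernel_size=3, padding=1)   # D -> k*D channels

    def forward(self, y):
        b, d, t = y.shape
        out = self.conv(y)                          # (B, k*D, T)
        out = out.view(b, self.k, d, t)             # split the factor k off the channel axis
        out = out.permute(0, 2, 3, 1)               # (B, D, T, k)
        return out.reshape(b, d, t * self.k)        # interleave: 1D analogue of PixelShuffle

up = UpSampling1D(d=2, k=12)
print(up(torch.randn(4, 2, 40)).shape)              # torch.Size([4, 2, 480])
```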
Downsampling Module We also set up a downsampling module $\mathrm{DownSampling}_i$ for each $\hat{Y}^{1}_{TR_i}$, $i < R_T$. In the same experimental settings as mentioned above, we implement the downsampling module as an aggregation operator of the same type as the aggregation used to generate data in coarser temporal resolutions. For the temporal resolution $TR_i$, $i < R_T$, the operator aggregates the prediction $\hat{Y}^{1}_{TR_i}$ every $r_{i+1}/r_i$ steps to form the downsampled prediction $\hat{Y}^{ds}_{TR_{i+1}}$.
Weighted Summation of Forecasting Results For the $i$-th temporal resolution ($1 < i < R_T$), we use the weighted summation of the first-stage result, the upsampled result, and the downsampled result as the second-stage prediction. The weights are the output of a learnable module, implemented as a CNN with a softmax activation function, defined for the $i$-th resolution:
\[
w_{TR_i} = f_{\mathrm{weight}}(\mathrm{concat}(\hat{Y}^{1}_{TR_i}, \hat{Y}^{ups}_{TR_i}, \hat{Y}^{ds}_{TR_i})) \tag{4.4}
\]
\[
\hat{Y}_{TR_i} = w_{TR_i}[:, :, 0] \otimes \hat{Y}^{1}_{TR_i} + w_{TR_i}[:, :, 1] \otimes \hat{Y}^{ups}_{TR_i} + w_{TR_i}[:, :, 2] \otimes \hat{Y}^{ds}_{TR_i}, \tag{4.5}
\]
where $\otimes$ is the element-wise product. For cases where $i = 1$ [resp. $R_T$], the weighted summation is calculated in the same way but with the $\hat{Y}^{ds}_{TR_i}$ [resp. $\hat{Y}^{ups}_{TR_i}$] term removed.
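A brief sketch of the weighted combination in Eqs. (4.4)-(4.5); the weight network f_weight is simplified to a single Conv1d here, which is an assumption of the example.

```python
import torch
import torch.nn as nn

D, T = 2, 480
y1, y_ups, y_ds = (torch.randn(4, D, T) for _ in range(3))     # first-stage / upsampled / downsampled forecasts

f_weight = nn.Conv1d(3 * D, 3 * D, kernel_size=3, padding=1)   # one weight per source and channel
w = f_weight(torch.cat([y1, y_ups, y_ds], dim=1)).view(4, 3, D, T).softmax(dim=1)  # softmax over the 3 sources

y_final = w[:, 0] * y1 + w[:, 1] * y_ups + w[:, 2] * y_ds      # Eq. (4.5)
print(y_final.shape)
```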
4.3.3 Loss Function
Denote the forecasting loss function by $\mathcal{L}(\hat{Y}_{\mathrm{pred}}, Y_{\mathrm{true}})$. The total loss for training the model is
\[
\mathcal{L}_{\mathrm{total}} = \sum_{i=1}^{R_T} \mathcal{L}_{TR_i}, \quad \text{where} \tag{4.6}
\]
\[
\mathcal{L}_{TR_i} = \sum_{\hat{Y} \in \{\hat{Y}^{DL}_{TR_i},\, \hat{Y}^{K}_{TR_i},\, \hat{Y}^{1}_{TR_i},\, \hat{Y}^{ups}_{TR_i},\, \hat{Y}^{ds}_{TR_i},\, \hat{Y}_{TR_i}\}} \mathcal{L}(\hat{Y}, Y_{TR_i}) + \mathcal{L}^{(KM)}_{TR_i} \tag{4.7}
\]
\[
\mathcal{L}^{(KM)}_{TR_i} = \sum_{1 \le p, q \le |TR_i|} \Big\| \, \big|g(Y_{TR_i}[p]) - g(Y_{TR_i}[q])\big| - \big|Y_{TR_i}[p] - Y_{TR_i}[q]\big| \, \Big\|. \tag{4.8}
\]
$\mathcal{L}^{(KM)}_{TR_i}$ encourages the mapping from the state space to the invariant subspace in the Koopman module to preserve distances.
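As an illustration, the distance-preservation term can be computed as in the sketch below; because the hidden and state spaces have different dimensions, each pairwise difference is reduced to a scalar norm here, which is an interpretation of Eq. (4.8) rather than the exact implementation.

```python
import torch

def koopman_distance_loss(y_true, g):
    """Encourage g to preserve pairwise distances between frames (cf. Eq. 4.8)."""
    H = g(y_true)                                                     # (T, k) encoded frames
    d_hidden = (H[:, None, :] - H[None, :, :]).norm(dim=-1)           # pairwise distances in the subspace
    d_state = (y_true[:, None, :] - y_true[None, :, :]).norm(dim=-1)  # pairwise distances of frames
    return (d_hidden - d_state).abs().sum()                           # sum over all pairs (p, q)

g = torch.nn.Linear(2, 8)                                             # stand-in for the Koopman encoder MLP
print(koopman_distance_loss(torch.randn(30, 2), g))
```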
4.4 Experiments
Datasets We evaluate the performance of ST-KMRN and all baselines on 3 datasets: (1) New York
Yellow Taxi Trip Record Data (YellowCab) [72] in 2017-2019; (2) New York Green Taxi Trip Record Data
(GreenCab) [72] in 2017-2019; and (3) Solar Energy Data (Solar Energy) [71] of Alabama in 2006. We use
sliding windows to generate input/output sequence pairs ordered by starting time and divide all pairs into
train/validation/test sets with the ratio 60%/20%/20%. Since all datasets are originally in a single resolution,
we construct one extra spatial resolution (named as "agg regions") and two extra temporal resolutions on
all datasets by aggregation to simulate the multi-resolution scenario.
Here we introduce the method to construct multi-resolution datasets used in our experiments from raw
data in a single resolution.
YellowCab & GreenCab We construct the YellowCab and GreenCab datasets with the NYCTaxi Trip
Record Data∗ between the year 2017 and 2019, which contains the yellow and green taxi trip records
including fields of pick-up/drop-off times and locations. The pick-up/drop-off times are precise, while the
locations are discrete values from a pre-defined set of regions†
.
To construct one dataset, we first divide the whole time range (3 years) into time windows of equal
size, then for each time window and each region, we count the total number of trips starting/ending within
the time window and the region as the pick-up/drop-off numbers. The taxi demand forecasting task is
to predict the pick-up and drop-off numbers of each time window and each region into the future. We
construct two datasets with the yellow and green taxi trip records respectively as YellowCab and GreenCab.
Following the practice in [105], we choose a 30-minute time window as the finest temporal resolution.
We construct the spatial graph among regions as follows: an undirected edge (i, j) exists iff the average
number of trips between the i-th region and the j-th region is at least 1 in 6 hours. The graph contains
1607/374 undirected edges among regions for YellowCab/GreenCab respectively.
We choose [6-hour, 1-day] as additional coarser temporal resolutions, and the New York City United
Hospital Fund Neighborhoods (UHF42) as an additional coarser spatial resolution. Figure 4.3a and Figure 4.3b
visualize the two levels of spatial resolutions of YellowCab and GreenCab.
Solar Energy We construct the Solar Energy dataset with National Renewable Energy Laboratory
(NREL)’s solar photovoltaic (PV) power plant data points for the state of Alabama in the United States
representing the year 2006‡
. The dataset contains 5-minute power output of 137 solar power plants in
Alabama. We follow the practice in [45] and select the 10-minute time window as the finest temporal
resolution: the whole time range is divided into 10-minute time windows and the value of each time window
is the summation of values falling into it from the original data.
∗
https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page
†
https://s3.amazonaws.com/nyc-tlc/misc/taxi+_zone_lookup.csv
‡
https://www.nrel.gov/grid/assets/downloads/al-pv-2006.zip
We build the spatial graph's adjacency matrix among regions (power plants) using a Gaussian kernel with a threshold: $W_{i,j} = d_{i,j}$ if $d_{i,j} \ge \kappa$ else $0$, where $d_{i,j} = \exp\left(-\frac{\mathrm{dist}(v_i, v_j)^2}{\sigma^2}\right)$, $\mathrm{dist}(v_i, v_j)$ is the straight-line distance from plant $v_i$ to plant $v_j$, $\sigma$ is the standard deviation of the distances, and $\kappa$ is the threshold. We set $\kappa = 0.95$ for the Solar Energy dataset. The constructed graph has 967 directed edges (including self loops).
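A small sketch of this adjacency construction with hypothetical plant coordinates; computing the straight-line distances with SciPy's cdist is an implementation choice of the example.

```python
import numpy as np
from scipy.spatial.distance import cdist

coords = np.random.default_rng(0).uniform(size=(137, 2))   # hypothetical plant locations
dist = cdist(coords, coords)                               # straight-line distances

sigma = dist.std()                                         # standard deviation of distances
d = np.exp(-(dist ** 2) / sigma ** 2)                      # Gaussian kernel weights
kappa = 0.95
W = np.where(d >= kappa, d, 0.0)                           # thresholded adjacency (self loops kept)
print(int((W > 0).sum()))                                  # number of directed edges incl. self loops
```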
We choose [1-hour, 6-hour] as additional coarser temporal resolutions, and we cluster all plants into 10
groups as the additional coarser spatial resolution using KMeans. Figure 4.3c shows the spatial resolutions
of Solar Energy.
Figure 4.3: Visualization of spatial resolutions ((a) YellowCab, (b) GreenCab, (c) Solar Energy). In Figure 4.3a and Figure 4.3b, each block bounded in black is one taxi zone in the finer spatial resolution, while blocks of the same color form one UHF42 region in the coarser spatial resolution. The geometric centers of the blocks are used to build the graphs. In Figure 4.3c, each point is the location of one power plant in the finer spatial resolution, and points of the same color form one cluster in the coarser spatial resolution. Stars are the cluster centers.
Table 4.1 shows statistics and task settings of each dataset. We use sliding windows with strides to
generate input/output sequence pairs ordered by starting time. Each pair is one data sample. Then we
divide all samples into train/validation/test sets with the ratio 60%/20%/20%. Table 4.2 shows details of data
splits of each dataset.
Baselines We compare our model ST-KMRN with the following baselines: (1) Historical Averaging (HA):
We use the averaged value of historical frames with the 1-week period as the prediction. (2) Static: We
use the value from the last available frame in the input sequence with the 1-week period as the prediction.
Table 4.1: Datasets and tasks.
YellowCab GreenCab Solar Energy
time step size 30min 30min 10min
# time steps 52560 52560 52560
temporal
resolutions [30min, 6h, 1d] [30min, 6h, 1d] [10min, 1h, 6h]
# spatial regions 67 76 137
# agg regions 11 12 10
input seq lengths [1440, 120, 30] [1440, 120, 30] [1440, 240, 40]
target temporal
resolution 30min 30min 10min
output seq length 480 480 432
max output horizon 10 days 10 days 3 days
target spatial
resolution spatial regions spatial regions agg regions
Table 4.2: Details of data splits.
YellowCab GreenCab Solar Energy
time step size 30min 30min 10min
# stride 48 (24 hours) 48 (24 hours) 36 (6 hours)
# train samples 618 618 825
# valid samples 210 210 281
# test samples 210 210 281
Table 4.3: Forecasting results within the maximum possible horizon from fully observed input sequences.
The lowest error is marked in bold and the second-lowest error in italic with underline. The row "RelErr"
shows the relative error change to the best baseline model and "RelErrGW" to the Graph WaveNet baseline.
Data YellowCab GreenCab Solar Energy
Metric MAE RMSE MAE RMSE MAE RMSE
HA 21.934 35.473 3.764 5.701 69.401 152.590
Static 14.004 26.324 2.076 3.080 71.758 179.680
GRU 25.625(0.188) 38.207(0.180) 2.668(0.056) 3.894(0.082) 114.667(9.018) 201.622(10.462)
Informer 22.210(2.046) 33.484(2.722) 1.926(0.062) 2.818(0.079) 81.456(5.181) 171.407(10.061)
Graph WaveNet 16.889(0.115) 30.929(0.257) 1.801(0.011) 2.787(0.019) 135.644(0.062) 290.471(0.186)
MTGNN 18.914(0.592) 34.017(0.985) 2.222(0.022) 3.462(0.009) 130.357(3.472) 283.506(5.695)
KoopmanAE 16.922(0.680) 28.115(0.891) 2.612(0.131) 3.832(0.196) 148.614(2.297) 253.097(0.421)
ST-KMRN 13.631(0.812) 24.090(1.054) 1.682(0.004) 2.479(0.017) 67.718(1.620) 148.304(2.279)
RelErr -2.7% -8.5% -6.6% -11.1% -2.4% -2.8%
RelErrGW -19.3% -22.1% -6.6% -11.1% -50.1% -48.9%
(3) Gated Recurrent Unit (GRU) [17] (4) Informer [113] (5) Graph WaveNet [102] (6) MTGNN [101] (7)
KoopmanAE [3]. For all baselines, we concatenate input in multiple temporal resolutions as features into
one single temporal resolution. We have also considered other recent works on short-term spatiotemporal
forecasting as baselines, such as AGCRN [4], GMAN [112], and DGCRN [47]. However, their designs incur high computation and memory complexity on long sequential data, which prevents us from obtaining results in practice.
4.4.1 Long-Sequence Forecasting With Fully Observed Input
Set-up We train all models and evaluate their forecasting performance within certain horizons into the
future. We repeat each experiment 3 times and report mean values and standard deviations of all metrics.
Discussion Table 4.3 shows the forecasting performance within the maximum possible horizon (10-day
for YellowCab/GreenCab and 3-day for Solar Energy) of baselines and our proposed ST-KMRN measured in
Mean Absolute Error (MAE) and Root Mean Squared Error (RMSE). Table 4.4 presents the results of
forecasting from fully observed data within multiple horizons.
(1) On all 3 datasets, ST-KMRN outperforms the baselines in most horizons, demonstrating the advantage of ST-KMRN in long-sequence forecasting with multi-resolution data. (2) ST-KMRN achieves significantly
Table 4.4: Forecasting results within multiple horizons from fully observed input sequences. The lowest
error is marked in bold and the second-lowest error in italic with underline. The row "RelErr" shows the
relative error change to the best baseline model and "RelErrGW" to the Graph WaveNet baseline.
(a) YellowCab
Horizon 30min 6h 1d 10d
Metric MAE RMSE MAE RMSE MAE RMSE MAE RMSE
HA 19.43 30.04 10.71 19.89 21.57 34.68 21.93 35.47
Static 12.50 21.87 6.95 14.78 13.45 25.00 14.00 26.32
GRU 22.69(1.85) 32.98(3.09) 14.75(0.67) 21.99(1.13) 25.57(0.15) 38.34(0.52) 25.62(0.19) 38.21(0.18)
Informer 20.24(1.15) 28.78(1.77) 14.45(0.62) 23.11(0.68) 21.49(1.83) 32.10(2.36) 22.21(2.05) 33.48(2.72)
Graph WaveNet 20.26(0.84) 32.42(0.28) 11.98(0.22) 26.11(0.46) 16.45(0.08) 30.16(0.13) 16.89(0.12) 30.93(0.26)
MTGNN 22.26(1.17) 38.32(1.76) 12.37(0.49) 28.58(0.79) 18.28(0.63) 33.22(0.99) 18.91(0.59) 34.02(0.98)
KoopmanAE 15.45(0.35) 24.11(0.65) 10.31(0.34) 16.92(0.75) 16.71(0.50) 27.20(0.65) 16.92(0.68) 28.11(0.89)
ST-KMRN 12.27(0.64) 18.06(0.47) 7.53(0.60) 13.73(0.99) 13.32(0.92) 23.07(1.32) 13.63(0.81) 24.09(1.05)
RelErr -1.8% -17.4% 8.3% -7.1% -1.0% -7.7% -2.6% -8.5%
RelErrGW -39.4% -44.3% -37.1% -47.4% -19.0% -23.5% -19.3% -22.1%
(b) GreenCab
Horizon 30min 6h 1d 10d
Metric MAE RMSE MAE RMSE MAE RMSE MAE RMSE
HA 4.37 6.62 3.25 5.67 3.74 5.67 3.76 5.70
Static 1.95 2.79 1.58 2.29 2.07 3.06 2.08 3.08
GRU 3.54(0.19) 5.17(0.32) 2.33(0.14) 3.52(0.23) 2.67(0.06) 3.89(0.09) 2.67(0.06) 3.89(0.08)
Informer 1.80(0.05) 2.60(0.05) 1.51(0.07) 2.18(0.08) 1.84(0.05) 2.73(0.07) 1.93(0.06) 2.82(0.08)
Graph WaveNet 1.79(0.03) 2.68(0.03) 1.54(0.04) 2.43(0.05) 1.80(0.01) 2.78(0.02) 1.80(0.01) 2.79(0.02)
MTGNN 2.26(0.02) 3.33(0.01) 1.87(0.01) 2.94(0.02) 2.22(0.02) 3.46(0.01) 2.22(0.02) 3.46(0.01)
KoopmanAE 2.35(0.32) 3.30(0.41) 1.65(0.04) 2.36(0.09) 2.79(0.09) 4.12(0.14) 2.61(0.13) 3.83(0.20)
ST-KMRN 1.61(0.04) 2.37(0.05) 1.30(0.03) 2.04(0.07) 1.65(0.01) 2.46(0.01) 1.68(0.00) 2.48(0.02)
RelErr -10.1% -8.8% -13.9% -6.4% -8.3% -9.9% -6.7% -11.1%
RelErrGW -10.1% -11.6% -15.6% -16.0% -8.3% -11.5% -6.7% -11.1%
(c) Solar Energy
Horizon 10min 1h 6h 3d
Metric MAE RMSE MAE RMSE MAE RMSE MAE RMSE
HA 53.1 143.3 62.5 147.6 68.7 151.9 69.4 152.6
Static 63.1 188.3 63.8 185.7 71.0 179.2 71.8 179.7
GRU 93.2(3.9) 148.0(4.5) 106.1(4.1) 169.4(4.6) 110.1(12.2) 195.6(14.8) 114.7(9.0) 201.6(10.5)
Informer 68.8(4.1) 148.7(6.1) 62.5(1.2) 140.2(5.7) 79.1(5.0) 161.6(11.0) 81.5(5.2) 171.4(10.1)
Graph WaveNet 124.9(0.2) 310.1(0.2) 123.0(0.3) 301.4(0.3) 136.3(0.1) 291.5(0.3) 135.6(0.1) 290.5(0.2)
MTGNN 117.2(5.8) 296.9(10.5) 115.0(5.7) 288.9(10.2) 129.7(4.8) 282.6(7.3) 130.4(3.5) 283.5(5.7)
KoopmanAE 151.4(3.0) 276.0(0.7) 149.9(1.7) 267.9(0.5) 147.3(1.6) 253.0(0.3) 148.6(2.3) 253.1(0.4)
ST-KMRN 45.4(1.1) 118.9(1.9) 44.5(0.1) 121.9(1.0) 64.3(1.3) 140.1(2.0) 67.7(1.6) 148.3(2.3)
RelErr -14.5% -17.0% -28.8% -13.1% -6.4% -7.8% -2.4% -2.8%
RelErrGW -63.7% -61.7% -63.8% -59.6% -52.8% -51.9% -50.1% -49.0%
decreased forecasting errors compared to the Graph WaveNet baseline, with which its encoders and decoders share similar neural network architectures. The results demonstrate that the gain of ST-KMRN comes
from the enhanced modeling of both intra-resolution and inter-resolution dynamics. (3) On the YellowCab
and Solar Energy datasets, our proposed ST-KMRN still holds the best performance compared to other
baselines, but its advantages over HA and Static are not as large as on the GreenCab dataset. This is due
to the strong periodicity within the fully observed input sequences. We further conduct experiments in a
more challenging setting: input data is partially observed, where the periodicity in input is weaker.
4.4.2 Long-Sequence Forecasting Results With Partially Observed Input
Set-up We evaluate the effect of missing data in input sequences by training and evaluating all models
with varying ratios of observable frames in the input of the finest spatial and temporal resolution (named
as “obs ratio”). This setting aims to simulate the practical scenario when the model needs to forecast with
low-quality input data of high resolutions.
Discussion Table 4.5 shows the 10-day forecasting performance of all models with various obs ratios
on YellowCab, GreenCab and Solar Energy. When the input data suffers from high missing ratios, the
periodicity of data is less beneficial for forecasting, and the capability of capturing non-periodic patterns
becomes more critical. Thus, trainable models start having advantages over the Static baseline. Under all obs ratios we select (0.8/0.6/0.4/0.2), our proposed ST-KMRN achieves the lowest forecasting errors, especially under the lower observation ratios (0.4 and 0.2), i.e., higher missing ratios, demonstrating its advantages with partially observed input.
Note that we here simulate the practical scenario where data of the highest temporal resolution and the highest spatial resolution suffers from missing values while data of the other resolutions are fully observed. Since the forecasting task on Solar Energy targets the lower spatial resolution, its input at that resolution is still fully observed, and thus the results of HA and Static do not vary with the observation ratio.
Table 4.5: Forecasting results with partially observed input (Horizon = 10d).
(a) YellowCab
Obs Ratio 0.8 0.6 0.4 0.2
Metric MAE RMSE MAE RMSE MAE RMSE MAE RMSE
HA 21.95 35.52 21.97 35.58 21.99 35.62 22.08 35.82
Static 14.21 26.66 15.77 29.73 22.49 40.69 41.19 61.80
GRU 22.83(1.15) 34.97(1.62) 27.74(0.40) 41.21(0.46) 27.75(0.40) 41.27(0.45) 27.78(0.42) 41.31(0.46)
Informer 19.04(0.98) 29.21(1.29) 23.83(1.93) 35.35(2.50) 22.90(0.35) 34.31(0.50) 23.30(1.53) 34.66(1.96)
Graph WaveNet 16.53(0.29) 31.10(0.43) 17.08(0.30) 31.01(0.39) 16.74(0.21) 30.76(0.35) 16.75(0.33) 30.72(0.16)
MTGNN 18.43(0.88) 32.73(2.02) 19.41(0.40) 34.88(0.71) 19.81(0.09) 35.49(0.11) 18.95(0.11) 33.46(0.66)
KoopmanAE 18.43(1.57) 29.79(2.20) 19.90(1.47) 31.58(1.94) 18.96(2.08) 30.09(2.99) 19.07(0.58) 30.66(0.77)
ST-KMRN 12.56(0.66) 22.58(0.71) 11.66(0.36) 21.46(0.58) 11.38(0.28) 21.17(0.45) 11.76(0.31) 21.64(0.48)
RelErr -11.6% -15.3% -26.1% -27.8% -32.0% -29.6% -29.8% -29.4%
RelErrGW -24.0% -27.4% -31.7% -30.8% -32.0% -31.2% -29.8% -29.6%
(b) GreenCab
Obs Ratio 0.8 0.6 0.4 0.2
Metric MAE RMSE MAE RMSE MAE RMSE MAE RMSE
HA 3.76 5.70 3.77 5.71 3.77 5.73 3.79 5.78
Static 2.08 3.09 2.14 3.17 2.40 3.47 3.09 4.18
GRU 2.67(0.04) 3.90(0.07) 2.66(0.04) 3.89(0.07) 2.67(0.03) 3.90(0.05) 2.67(0.05) 3.91(0.08)
Informer 1.96(0.07) 2.88(0.10) 2.03(0.07) 2.95(0.08) 2.03(0.02) 2.99(0.07) 2.08(0.05) 3.00(0.07)
Graph WaveNet 1.77(0.01) 2.75(0.00) 1.79(0.03) 2.78(0.05) 1.80(0.02) 2.79(0.01) 1.82(0.01) 2.84(0.02)
MTGNN 2.09(0.11) 3.24(0.15) 2.21(0.08) 3.46(0.06) 2.21(0.04) 3.43(0.03) 2.23(0.01) 3.50(0.07)
KoopmanAE 2.78(0.55) 4.01(0.72) 2.94(0.44) 4.24(0.50) 2.97(0.45) 4.23(0.50) 2.95(0.50) 4.26(0.59)
ST-KMRN 1.73(0.02) 2.51(0.03) 1.72(0.04) 2.53(0.03) 1.68(0.03) 2.47(0.04) 1.68(0.04) 2.47(0.04)
RelErr -2.3% -8.7% -3.9% -9.0% -6.7% -11.5% -7.7% -13.0%
RelErrGW -2.3% -8.7% -3.9% -9.0% -6.7% -11.5% -7.7% -13.0%
(c) Solar Energy
Obs Ratio 0.8 0.6 0.4 0.2
Metric MAE RMSE MAE RMSE MAE RMSE MAE RMSE
HA 200.7 255.2 200.7 255.2 200.7 255.2 200.7 255.2
Static 261.1 409.7 261.1 409.7 261.1 409.7 261.1 409.7
GRU 118.3(8.7) 204.2(11.3) 118.5(12.8) 206.0(14.8) 113.9(8.7) 200.4(11.0) 111.8(6.0) 192.1(5.2)
Informer 91.5(3.0) 186.5(7.2) 90.8(1.0) 182.0(3.8) 86.7(5.3) 177.4(8.1) 97.8(7.2) 188.8(5.4)
Graph WaveNet 135.6(0.0) 290.7(0.1) 135.7(0.1) 290.6(0.1) 135.7(0.1) 290.6(0.1) 135.7(0.0) 290.5(0.1)
MTGNN 122.9(8.8) 268.2(16.7) 79.9(5.3) 167.6(19.7) 106.1(25.3) 225.7(62.3) 107.2(21.4) 233.5(45.2)
KoopmanAE 147.5(0.7) 252.8(0.3) 152.0(7.3) 253.9(1.8) 147.8(0.6) 252.9(0.2) 146.9(1.7) 252.7(0.5)
ST-KMRN 68.3(1.3) 148.0(4.0) 67.9(2.1) 145.6(4.0) 68.1(1.8) 149.0(2.8) 72.3(2.8) 157.7(5.6)
RelErr -25.4% -20.6% -15.0% -13.1% -21.5% -16.0% -26.1% -16.5%
RelErrGW -49.6% -49.1% -50.0% -49.9% -49.8% -48.7% -46.7% -45.7%
4.4.3 Ablation Study
Table 4.6: Ablation study results on YellowCab (Obs Ratio = 0.8) with the relative change of errors after removing each component.

Horizon          6h                            10d
Metric           MAE            RMSE           MAE            RMSE
w/o Self-Attn    11.23(2.09)    15.76(2.12)    14.41(2.36)    23.64(2.16)
                 +58.84%        +26.38%        +14.73%        +4.69%
w/o Koopman      8.70(0.59)     18.63(0.81)    14.67(0.60)    26.42(0.81)
                 +23.06%        +49.40%        +16.80%        +17.01%
w/o ups/ds       7.94(0.51)     12.87(0.52)    13.45(0.20)    23.18(0.12)
                 +12.31%        +3.21%         +7.09%         +2.66%
ST-KMRN          7.07(0.64)     12.47(0.58)    12.56(0.66)    22.58(0.71)
Set-up We conduct the ablation study on the YellowCab dataset with 80% observation ratio to evaluate
the contribution of proposed components. By removing each component from ST-KMRN, we have the
following settings: (1) w/o Self-Attn: ST-KMRN without the self-attention module. (2) w/o Koopman:
ST-KMRN without the Koopman module. The prediction will only be performed from the neural network
module. (3) w/o ups/ds: ST-KMRN without upsampling and downsampling modules. Only the decoder of
the target resolution will forecast. We repeat the experiments of each setting 3 times and report the results in Table 4.6. A larger increase in error after removing a module indicates a larger contribution.
Discussion From Table 4.6, we observe that (1) the self-attention module applied on representations of
different resolutions significantly improves performance. It allows information exchange among resolutions,
and it contributes most on shorter horizons (6h). (2) The combination of both Koopman and neural
network modules brings significant boosts to the forecasting performance for both short (6h) and long (10d)
prediction horizons. It captures the correct patterns for each temporal resolution, which we will detail
in Section "Interpretability of Koopman Module". (3) The upsampling and downsampling modules can
improve performance on all forecasting horizons as (a) the downsampling module regularizes the output
by minimizing the difference between its aggregated values with ground truth values in lower resolutions;
(b) the upsampling module is trained to predict local temporal patterns and can fix the errors produced by the decoder, which forecasts values at all steps at once.

Table 4.7: Effect of the number of resolutions on prediction performance.

Horizon   30min                          6h                             1d                             10d
Metric    MAE            RMSE            MAE            RMSE            MAE            RMSE            MAE            RMSE
1 Res     13.80(0.22)    19.32(0.32)     7.66(0.20)     12.88(0.18)     13.25(0.64)    22.68(0.92)     13.47(0.64)    23.66(0.88)
          +12.47%        +6.98%          +1.73%         -6.19%          -0.53%         -1.69%          -1.17%         -1.78%
2 Res     13.00(0.87)    18.65(0.44)     8.40(1.52)     13.78(0.97)     13.58(0.36)    22.91(0.17)     13.95(0.51)    23.98(0.34)
          +5.95%         +3.27%          +11.55%        +0.36%          +1.95%         -0.69%          +2.35%         -0.46%
3 Res     12.27(0.64)    18.06(0.47)     7.53(0.60)     13.73(0.99)     13.32(0.92)    23.07(1.32)     13.63(0.81)    24.09(1.05)

Figure 4.4: Eigenvalues of Koopman matrices in the learned hidden space of input sequences in different resolutions: (a) 30-min, (b) 6-hour, (c) 1-day. For each subfigure, the left part shows the distribution of eigenvalues on the complex plane with the unit circle, and the right part displays the distribution of maximum magnitudes of complex eigenvalues w.r.t. periods (period = 2π/angle × time window).
4.4.4 Effect of the number of resolutions on the performance.
We evaluate the effect of the number of temporal resolutions on the prediction performance with the YellowCab dataset. Results are shown in Table 4.7. We find that when the number of available temporal resolutions is reduced gradually, starting from the coarsest resolution, the prediction errors (MAE/RMSE) for the prediction 30 minutes into the future increase by 12.47%/6.98% (for 1 resolution) and 5.95%/3.27% (for 2 resolutions), respectively. However, prediction performance for longer horizons (6h/1d/10d) does not differ significantly. These results demonstrate that increasing the number of resolutions in the observations benefits short-term prediction performance.
4.4.5 Interpretability of Koopman Module: Revealing Dynamics in Each Resolution
Set-up Since the Koopman module models the temporal dynamics as a linear system in the hidden space, the derived Koopman matrix describing the system can provide rich interpretable information. By applying eigendecomposition to the Koopman matrix, we can decompose the system dynamics into components with different magnitudes and periods. An eigenvalue $\lambda e^{j\theta}$ corresponds to a component $f(t) = \lambda^t e^{j\theta t}$ with magnitude $\lambda^t$ and frequency $\theta$ (i.e., period $T = \frac{2\pi}{\theta}$).
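A small sketch of how the periods shown in Figure 4.4 can be read off a Koopman matrix; the matrix here is random and only illustrates the computation, and time_window is the step size of the chosen resolution (e.g., 0.5 hours for 30-minute data).

```python
import numpy as np

K = np.random.default_rng(0).normal(size=(32, 32)) / np.sqrt(32)  # placeholder Koopman matrix
eigvals = np.linalg.eigvals(K)

time_window = 0.5                                   # hours per step for the 30-minute resolution
angles = np.angle(eigvals)                          # theta of lambda * e^{j*theta}
nonzero = np.abs(angles) > 1e-6                     # drop non-periodic (zero-angle) components
periods = 2 * np.pi / np.abs(angles[nonzero]) * time_window   # period = 2*pi / angle * time window
magnitudes = np.abs(eigvals[nonzero])
print(periods[:5], magnitudes[:5])
```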
Discussion Figure 4.4 displays the distribution of eigenvalues of Koopman matrices for the input data of
each resolution. We observe that: (1) For data in the resolution of 30-min, 6-hour, and 1-day, the maximum of
magnitudes of eigenvalues reaches its peak value around the period of 24-hour, 1-day, and 7-day respectively
(we ignore the eigenvalues with zero angles since they represent components with infinite periods and thus
are non-periodic). This matches our observation that the taxi demand data shows strong daily and weekly
patterns. (2) For data in the resolution of 30-min and 6-hour, we can still notice eigenvalues with large
magnitudes around the 1-week period, but it is hard to distinguish them from other components including
non-periodic ones. The reason is that in data with high temporal resolutions, long periods correspond to
large numbers of time steps T and frequency values θ close to 0. When we switch to data with a lower
temporal resolution (1-day), the pattern of the same period can be represented by components with larger
distances to non-periodic components in the spectral domain and are easier to identify. This again implies
the necessity of utilizing multi-resolution data for better modeling temporal patterns with various periods.
4.5 Conclusion
We propose the Spatiotemporal Koopman Multi-Resolution Network (ST-KMRN), which boosts long-sequence forecasting with enhanced modeling of inter-resolution and intra-resolution dynamics in multi-resolution data. To accomplish this, we introduce self-attention among representations of different resolutions, together with upsampling and downsampling modules, for better utilization of inter-resolution contextual information. Meanwhile, we improve the modeling of intra-resolution dynamics with the combination of Koopman theory based modeling and deep learning based forecasting. ST-KMRN achieves state-of-the-art performance on the
long-sequence forecasting tasks from multiple real-world spatiotemporal datasets. Limitations of our work
include that the available resolutions used by the model are fixed and are usually dependent on the task
and data availability, and the model only applies to data with regular time steps. Our future works include
extending the framework to find optimal data resolutions to guide the data collection process and enabling
it to model and forecast irregular spatiotemporal sequences.
Chapter 5
Cross-Node Federated Graph Neural Network for Spatio-Temporal Data
Modeling
5.1 Introduction
Modeling the dynamics of spatio-temporal data generated from networks of edge devices or nodes (e.g.
sensors, wearable devices and the Internet of Things (IoT) devices) is critical for various applications
including traffic flow prediction [54, 108], forecasting [87, 3], and user activity detection [104, 59]. While
existing works on spatio-temporal dynamics modeling [5, 40, 6] assume that the model is trained with
centralized data gathered from all devices, the volume of data generated at these edge devices precludes
the use of such centralized data processing, and calls for decentralized processing where computations
on the edge can significantly reduce latency. In addition, in the case of spatio-temporal
forecasting, the edge devices need to leverage the complex inter-dependencies to improve the prediction
performance. Moreover, with increasing concerns about data privacy and its access restrictions due to
existing licensing agreements, it is critical for spatio-temporal modeling to utilize decentralized data, yet
leveraging the underlying relationships for improved performance.
Although recent works in federated learning (FL) [35] provide a solution for training a model with decentralized data on multiple devices, these works either do not consider the inherent spatio-temporal
dependencies [66, 51, 37] or only model them implicitly by imposing the graph structure as a regularization on model weights [92]; the latter suffers from the limitations of regularization-based methods, which assume that graphs only encode node similarity [42], and cannot operate in settings where only a fraction of devices are observed during training (the inductive learning setting). As a result, there is a need for an architecture for spatio-temporal data modeling that enables reliable computation on the edge while keeping the data decentralized.
To this end, leveraging recent works on federated learning [35], we introduce the cross-node federated
learning requirement to ensure that data generated locally at a node remains decentralized. Specifically,
our architecture – Cross-Node Federated Graph Neural Network (CNFGNN), aims to effectively model
the complex spatio-temporal dependencies under the cross-node federated learning constraint. For this,
CNFGNN decomposes the modeling of temporal and spatial dependencies using an encoder-decoder model
on each device to extract the temporal features with local data, and a Graph Neural Network (GNN) based
model on the server to capture spatial dependencies among devices.
As compared to existing federated learning techniques that rely on regularization to incorporate spatial relationships, CNFGNN leverages an explicit graph structure using a graph neural network (GNN)-based architecture, which leads to performance gains. However, the federated learning constraint on data sharing means that the GNN cannot be trained in a centralized manner, since each node can only access the data stored on itself. To address this, CNFGNN employs Split Learning [91] to train the spatial and temporal
modules. Further, to alleviate the associated high communication cost incurred by Split Learning, we
propose an alternating optimization-based training procedure of these modules, which incurs only half
the communication overhead as compared to a comparable Split Learning architecture. Here, we also use
Federated Averaging (FedAvg) [66] to train a shared temporal feature extractor for all nodes, which leads to
improved empirical performance.
Our main contributions are as follows:
1. We propose Cross-Node Federated Graph Neural Network (CNFGNN), a GNN-based federated
learning architecture that captures complex spatio-temporal relationships among multiple nodes
while ensuring that the data generated locally remains decentralized at no extra computation cost at
the edge devices.
2. Our modeling and training procedure enables GNN-based architectures to be used in federated learning settings. We achieve this by disentangling the modeling of local temporal dynamics on edge devices from spatial dynamics on the central server, and by leveraging an alternating optimization-based procedure for updating the spatial and temporal modules using Split Learning and Federated Averaging to enable effective GNN-based federated learning.
3. We demonstrate that CNFGNN achieves the best prediction performance (both in transductive and
inductive settings) at no extra computation cost on edge devices with modest communication cost, as
compared to the related techniques on a traffic flow prediction task.
5.2 Cross-Node Federated Graph Neural Network
Table 5.1 summarizes notations used in this paper and their definitions.
5.2.1 Problem Formulation
Given a dataset with a graph G = (V, E), a feature tensor X ∈ R^{|V|×...} and a label tensor Y ∈ R^{|V|×...}, the task is defined on the dataset with X as the input and Y as the prediction target. We consider learning a model under the cross-node federated learning constraint: the node feature x_i = X_{i,...}, node label y_i = Y_{i,...}, and model output ŷ_i are only visible to node i.
One typical task that requires the cross-node federated learning constraint is the prediction of spatio-temporal data generated by a network of sensors. In such a scenario, V is the set of sensors and E describes relations among sensors (e.g., e_{ij} ∈ E if and only if the distance between v_i and v_j is below some threshold). The feature tensor x_i ∈ R^{m×D} represents the i-th sensor's records in the D-dim space during the past m time steps, and the label y_i ∈ R^{n×D} represents the i-th sensor's records in the future n time steps. Since records collected on different sensors owned by different users/organizations may not be allowed to be shared due to the need for edge computation or licensing issues on data access, it is necessary to design an algorithm modeling the spatio-temporal relation without any direct exchange of node-level data.

Table 5.1: Table of notations.

Notation            Definition
G = (V, E)          Graph G defined with the set of nodes V and the set of edges E.
X                   Tensor of node features, X ∈ R^{|V|×...}.
x_i                 Features of the i-th node.
Y                   Tensor of node labels for the task, Y ∈ R^{|V|×...}.
y_i                 Labels of the i-th node.
ŷ_i                 Model prediction output for the i-th node.
R_g / R_c / R_s     Maximum number of global/client/server training rounds.
θ_GN^(r_g)          Weights of the server-side Graph Network in the r_g-th global training round.
θ̄_c^(r_g)           Aggregated weights of client models in the r_g-th global training round.
h_c,i               Local embedding of the input sequence of the i-th node.
h_G,c,i             Embedding of the i-th node propagated with the server-side GN.
ℓ_i                 Loss calculated on the i-th node.
η_s                 Learning rate for training θ_GN^(r_g).
η_c                 Learning rate for training θ̄_c^(r_g).
5.2.2 Proposed Method
We now introduce our proposed Cross-Node Federated Graph Neural Network (CNFGNN) model. Here, we
begin by disentangling the modeling of node-level temporal dynamics and server-level spatial dynamics as
follows: (i) (Figure 5.1c) on each node, an encoder-decoder model extracts temporal features from data on the node and makes predictions; (ii) (Figure 5.1b) on the central server, a Graph Network (GN) [6] propagates the extracted node temporal features and outputs node embeddings, which incorporate the relational information among nodes. Step (i) has access to the non-shareable node data and is executed on each node locally. Step (ii) only involves the upload and download of smashed features and gradients instead of the raw data on nodes. This decomposition enables the exchange and aggregation of node information under the cross-node federated learning constraint.
Figure 5.1: Cross-Node Federated Graph Neural Network. (a) In each round of training, we alternately train
models on nodes and the model on the server. More specifically, we sequentially execute: (1) Federated
learning of on-node models. (2) Temporal encoding update. (3) Split Learning of GN. (4) On-node graph
embedding update. (b) Detailed view of the server-side GN model for modeling spatial dependencies in
data. (c) Detailed view of the encoder-decoder model on the i-th node.
Algorithm 1 Training algorithm of CNFGNN on the server side.

Input: Initial server-side GN weights θ_GN^(0), initial client model weights θ̄_c^(0) = {θ̄_c^(0),enc, θ̄_c^(0),dec}, the maximum number of global rounds R_g, the maximum number of client rounds R_c, the maximum number of server rounds R_s, server-side learning rate η_s, client learning rate η_c.
Output: Trained server-side GN weights θ_GN^(R_g), trained client model weights θ̄_c^(R_g).

Server executes:
1:  Initialize server-side GN weights with θ_GN^(0). Initialize client model weights with θ̄_c^(0).
2:  for each node i ∈ V in parallel do
3:      Initialize client model θ_c,i^(0) = θ̄_c^(0).
4:      Initialize graph encoding on node h_G,c,i = h_G,c,i^(0).
5:  end for
6:  for global round r_g = 1, 2, . . . , R_g do
7:      // (1) Federated learning of on-node models.
8:      for each client i ∈ V in parallel do
9:          θ_c,i ← ClientUpdate(i, R_c, η_c).
10:     end for
11:     θ̄_c^(r_g) ← Σ_{i∈V} (N_i / N) θ_c,i.
12:     for each client i ∈ V in parallel do
13:         Initialize client model: θ_c,i^(0) = θ̄_c^(r_g).
14:     end for
15:     // (2) Temporal encoding update.
16:     for each client i ∈ V in parallel do
17:         h_c,i ← ClientEncode(i).
18:     end for
19:     // (3) Split Learning of GN.
20:     Initialize θ_GN^(r_g,0) = θ_GN^(r_g−1).
21:     for server round r_s = 1, 2, . . . , R_s do
22:         {h_G,c,i | i ∈ V} ← GN({h_c,i | i ∈ V}; θ_GN^(r_g,r_s−1)).
23:         for each client i ∈ V in parallel do
24:             ∇_{h_G,c,i} ℓ_i ← ClientBackward(i, h_G,c,i).
25:             ∇_{θ_GN^(r_g,r_s−1)} ℓ_i ← h_G,c,i.backward(∇_{h_G,c,i} ℓ_i).
26:         end for
27:         ∇_{θ_GN^(r_g,r_s−1)} ℓ ← Σ_{i∈V} ∇_{θ_GN^(r_g,r_s−1)} ℓ_i.
28:         θ_GN^(r_g,r_s) ← θ_GN^(r_g,r_s−1) − η_s ∇_{θ_GN^(r_g,r_s−1)} ℓ.
29:     end for
30:     θ_GN^(r_g) ← θ_GN^(r_g,R_s).
31:     // (4) On-node graph embedding update.
32:     {h_G,c,i | i ∈ V} ← GN({h_c,i | i ∈ V}; θ_GN^(r_g)).
33:     for each client i ∈ V in parallel do
34:         Set graph encoding on client as h_G,c,i.
35:     end for
36: end for
5.2.2.1 Modeling of Node-Level Temporal Dynamics
We modify the Gated Recurrent Unit (GRU) based encoder-decoder architecture in [16] for modeling node-level temporal dynamics on each node.
Algorithm 2 Training algorithm of CNFGNN on the client side.

ClientUpdate(i, R_c, η_c):
1: for client round r_c = 1, 2, . . . , R_c do
2:     h_c,i^(r_c) ← Encoder_i(x_i; θ_c,i^(r_c−1),enc).
3:     ŷ_i ← Decoder_i(x_i,m, [h_c,i^(r_c); h_G,c,i]; θ_c,i^(r_c−1),dec).
4:     ℓ_i ← ℓ(ŷ_i, y_i).
5:     θ_c,i^(r_c) ← θ_c,i^(r_c−1) − η_c ∇_{θ_c,i^(r_c−1)} ℓ_i.
6: end for
7: θ_c,i = θ_c,i^(R_c).
8: return θ_c,i to server.

ClientEncode(i):
1: return h_c,i = Encoder_i(x_i; θ_c,i^enc) to server.

ClientBackward(i, h_G,c,i):
1: ŷ_i ← Decoder_i(x_i,m, [h_c,i; h_G,c,i]; θ_c,i^dec).
2: ℓ_i ← ℓ(ŷ_i, y_i).
3: return ∇_{h_G,c,i} ℓ_i to server.
Given an input sequence x_i ∈ R^{m×D} on the i-th node, an encoder sequentially reads the whole sequence and outputs the hidden state h_c,i as the summary of the input sequence according to Equation 5.1:

h_c,i = Encoder_i(x_i, h_c,i^(0)),    (5.1)

where h_c,i^(0) is a zero-valued initial hidden state vector.
To incorporate the spatial dynamics into the prediction model of each node, we concatenate h_c,i with the node embedding h_G,c,i generated from the procedure described in Section 5.2.2.2, which contains spatial information, and use the concatenation as the initial state vector of the decoder. The decoder generates the prediction ŷ_i in an auto-regressive way starting from the last frame of the input sequence x_i,m with the concatenated hidden state vector:

ŷ_i = Decoder_i(x_i,m, [h_c,i; h_G,c,i]).    (5.2)
We choose the mean squared error (MSE) between the prediction and the ground truth values as the loss
function, which is evaluated on each node locally.
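A minimal PyTorch sketch of this on-node module is shown below; the module names, dimensions, and the exact way the graph embedding enters the decoder state are illustrative assumptions rather than the reference implementation:

```python
import torch
import torch.nn as nn

class NodeSeq2Seq(nn.Module):
    """GRU encoder-decoder on one node; the decoder starts from the
    concatenation of the local summary h_c and the graph embedding h_G."""

    def __init__(self, input_dim=1, hidden_dim=64, graph_dim=64, horizon=12):
        super().__init__()
        self.horizon = horizon
        self.encoder = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.decoder = nn.GRUCell(input_dim, hidden_dim + graph_dim)
        self.readout = nn.Linear(hidden_dim + graph_dim, input_dim)

    def encode(self, x):                       # x: (batch, m, input_dim)
        _, h = self.encoder(x)                 # h: (1, batch, hidden_dim)
        return h.squeeze(0)                    # h_c: (batch, hidden_dim)

    def decode(self, x_last, h_c, h_g):        # x_last: (batch, input_dim)
        state = torch.cat([h_c, h_g], dim=-1)  # initial decoder state [h_c; h_G]
        preds, inp = [], x_last
        for _ in range(self.horizon):          # auto-regressive rollout
            state = self.decoder(inp, state)
            inp = self.readout(state)
            preds.append(inp)
        return torch.stack(preds, dim=1)       # (batch, horizon, input_dim)

# Local loss on one node (MSE between prediction and ground truth).
model = NodeSeq2Seq()
x = torch.randn(8, 12, 1)                      # past 12 steps
y = torch.randn(8, 12, 1)                      # next 12 steps
h_g = torch.zeros(8, 64)                       # graph embedding received from the server
h_c = model.encode(x)
y_hat = model.decode(x[:, -1, :], h_c, h_g)
loss = nn.functional.mse_loss(y_hat, y)
```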
5.2.2.2 Modeling of Spatial Dynamics
To capture the complex spatial dynamics, we adopt Graph Networks (GNs) proposed in [6] to generate node
embeddings containing the relational information of all nodes. The central server collects the hidden state
from all nodes {hc,i | i ∈ V} as the input to the GN. Each layer of GN updates the input features as follows:
e'_k = φ^e(e_k, v_{r_k}, v_{s_k}, u),    ē'_i = ρ^{e→v}(E'_i),
v'_i = φ^v(ē'_i, v_i, u),                ē' = ρ^{e→u}(E'),
u' = φ^u(ē', v̄', u),                     v̄' = ρ^{v→u}(V'),    (5.3)
where e_k, v_i, u are edge features, node features, and global features respectively; φ^e, φ^v, φ^u are neural networks; and ρ^{e→v}, ρ^{e→u}, ρ^{v→u} are aggregation functions such as summation. As shown in Figure 5.1b, we choose a 2-layer GN with residual connections for all experiments. We set v_i = h_c,i, e_k = W_{r_k,s_k} (W is the adjacency matrix), and assign the empty vector to u as the input of the first GN layer. The server-side GN outputs embeddings {h_G,c,i | i ∈ V} for all nodes and sends each node its corresponding embedding.
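The sketch below illustrates one such block in PyTorch under simplifying assumptions (scalar edge weights taken from W, summation aggregation, and no global-feature path); it shows the edge-update/aggregate/node-update pattern rather than the exact server-side model:

```python
import torch
import torch.nn as nn

class SimpleGNLayer(nn.Module):
    """One edge-update / aggregate / node-update block (summation aggregation).
    Edge features are the scalar weights W[r_k, s_k]; the global path is omitted."""

    def __init__(self, node_dim, hidden_dim=128):
        super().__init__()
        self.phi_e = nn.Sequential(nn.Linear(2 * node_dim + 1, hidden_dim), nn.ReLU(),
                                   nn.Linear(hidden_dim, hidden_dim))
        self.phi_v = nn.Sequential(nn.Linear(node_dim + hidden_dim, hidden_dim), nn.ReLU(),
                                   nn.Linear(hidden_dim, node_dim))

    def forward(self, v, edge_index, edge_weight):
        # v: (|V|, node_dim); edge_index: (2, |E|) with rows (receiver, sender)
        r, s = edge_index
        e_in = torch.cat([v[r], v[s], edge_weight.unsqueeze(-1)], dim=-1)
        e_out = self.phi_e(e_in)                              # per-edge update
        agg = torch.zeros(v.size(0), e_out.size(-1))
        agg = agg.index_add(0, r, e_out)                      # sum incoming edges per node
        return v + self.phi_v(torch.cat([v, agg], dim=-1))    # residual node update

# Toy usage: 4 nodes, 3 directed edges.
v = torch.randn(4, 64)                                        # stacked hidden states h_{c,i}
edge_index = torch.tensor([[1, 2, 3], [0, 0, 1]])             # (receiver, sender) pairs
edge_weight = torch.rand(3)                                   # W[r_k, s_k]
layer = SimpleGNLayer(node_dim=64)
h_g = layer(v, edge_index, edge_weight)                       # node embeddings h_{G,c,i}
```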
5.2.2.3 Alternating Training of Node-Level and Spatial Models
One challenge brought about by the cross-node federated learning requirement and the server-side GN
model is the high communication cost in the training stage. Since we distribute different parts of the
model on different devices, Split Learning proposed by [91] is a potential solution for training, where
hidden vectors and gradients are communicated among devices. However, when we simply train the model
end-to-end via Split Learning, the central server needs to receive hidden states from all nodes and to send
node embeddings to all nodes in the forward propagation, then it must receive gradients of node embeddings
from all nodes and send back gradients of hidden states to all nodes in the backward propagation. Assuming all hidden states and node embeddings have the same size S, the total amount of data transmitted in each training round of the GN model is 4|V|S.
To alleviate the high communication cost in the training stage, we instead alternately train models on
nodes and the GN model on the server. More specifically, in each round of training, we (1) fix the node
embedding hG,c,i and optimize the encoder-decoder model for Rc rounds, then (2) we optimize the GN
model while fixing all models on nodes. Since models on nodes are fixed, hc,i stays constant during the
training of the GN model, and the server only needs to fetch hc,i from nodes before the training of GN
starts and only to communicate node embeddings and gradients. Therefore, the average amount of data
transmitted in each round for R_s rounds of training of the GN model reduces to ((2 + 2R_s)/R_s)|V|S. We provide more details of the training procedure in Algorithm 1 and Algorithm 2.
To more effectively extract temporal features from each node, we also train the encoder-decoder models
on nodes with the FedAvg algorithm proposed in [66]. This enables all nodes to share the same feature
extractor and thus share a joint hidden space of temporal features, which avoids the potential overfitting of
models on nodes and demonstrates faster convergence and better prediction performance empirically.
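To make the saving concrete, the short calculation below compares the per-round transmission of end-to-end Split Learning (4|V|S) with the alternating scheme ((2 + 2R_s)/R_s · |V|S), using illustrative values for |V|, S, and R_s:

```python
# Per-round communication of the GN model (illustrative values, not measured ones).
num_nodes = 325          # |V|
hidden_size_gb = 2e-3    # S, size of one hidden state / node embedding in GB
server_rounds = 20       # R_s

split_learning = 4 * num_nodes * hidden_size_gb
alternating = (2 + 2 * server_rounds) / server_rounds * num_nodes * hidden_size_gb

print(f"Split Learning : {split_learning:.2f} GB per GN training round")
print(f"Alternating    : {alternating:.2f} GB per GN training round (average)")
# As R_s grows, the alternating scheme approaches half of the Split Learning cost.
```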
5.3 Experiments
We evaluate the performance of CNFGNN and all baseline methods on the traffic forecasting task, which
is an important application for spatio-temporal data modeling. The primary challenge in FL is to respect
constraints on data sharing and manipulation. These constraints can occur in scenarios where data contains
sensitive information, such as financial data owned by different institutions. Due to the sensitivity of data,
datasets from such scenarios are proprietary and hardly offer public access. Therefore, we demonstrate the
applicability of our proposed model on the traffic dataset, which is a good representative example of data
with spatio-temporal correlations, and has been extensively studied in spatio-temporal forecasting works
without FL constraints [54, 108]. Our proposed model is general and applicable to other spatio-temporal
datasets with sensitive information.
We reuse the following two real-world large-scale datasets in [54] and follow the same preprocessing
procedures: (1) PEMS-BAY: This dataset contains the traffic speed readings from 325 sensors in the Bay
Area over 6 months from Jan 1st, 2017 to May 31st, 2017. (2) METR-LA: This dataset contains the traffic
speed readings from 207 loop detectors installed on the highway of Los Angeles County over 4 months
from Mar 1st, 2012 to Jun 30th, 2012.
For both datasets, we construct the adjacency matrix of sensors using the Gaussian kernel with a threshold: W_{i,j} = d_{i,j} if d_{i,j} ≥ κ else 0, where d_{i,j} = exp(−dist(v_i, v_j)^2 / σ^2), dist(v_i, v_j) is the road network distance from sensor v_i to sensor v_j, σ is the standard deviation of distances, and κ is the threshold. We set κ = 0.1 for both datasets.
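A small NumPy sketch of this thresholded Gaussian-kernel construction, assuming a precomputed matrix of road-network distances, is given below:

```python
import numpy as np

def build_adjacency(dist, kappa=0.1):
    """Thresholded Gaussian kernel: W_ij = exp(-dist_ij^2 / sigma^2) if >= kappa else 0."""
    sigma = dist.std()                      # standard deviation of distances
    w = np.exp(-np.square(dist / sigma))
    w[w < kappa] = 0.0                      # sparsify with threshold kappa
    return w

# Toy example with 4 sensors and symmetric road-network distances.
dist = np.array([[0.0, 1.0, 5.0, 9.0],
                 [1.0, 0.0, 3.0, 7.0],
                 [5.0, 3.0, 0.0, 2.0],
                 [9.0, 7.0, 2.0, 0.0]])
W = build_adjacency(dist)
```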
We aggregate traffic speed readings in both datasets into 5-minute windows and truncate the whole
sequence to multiple sequences with length 24. The forecasting task is to predict the traffic speed in the
following 12 steps of each sequence given the first 12 steps. We show the statistics of both datasets in
Table 5.2.
Table 5.2: Statistics of datasets PEMS-BAY and METR-LA.
Dataset # Nodes # Directed Edges # Train Seq # Val Seq # Test Seq
PEMS-BAY 325 2369 36465 5209 10419
METR-LA 207 1515 23974 3425 6850
We show the histograms of traffic speed on different nodes of PEMS-BAY and METR-LA in Figure 5.2a
and Figure 5.2b. For each dataset, we only show the first 100 nodes ranked by their IDs for simplicity.
The histograms show that the data distribution varies with nodes, thus data on different nodes are not
independent and identically distributed.
5.3.1 Spatio-Temporal Data Modeling: Traffic Flow Forecasting
Baselines Here we introduce the settings of baselines and our proposed model CNFGNN. Unless noted
otherwise, all models are optimized using the Adam optimizer with the learning rate 1e-3.
Figure 5.2: Histograms of traffic speed on the first 100 nodes ranked by ID. (a) PEMS-BAY. (b) METR-LA.
• GRU (centralized): Gated Recurrent Unit (GRU) model trained with centralized sensor data. The
GRU model with 63K parameters is a 1-layer GRU with hidden dimension 100, and the GRU model
with 727K parameters is a 2-layer GRU with hidden dimension 200.
• GRU + GN (centralized): a model directly combining GRU and GN trained with centralized data,
whose architecture is similar to CNFGNN but all GRU modules on nodes always share the same
weights. We see its performance as the upper bound of the performance of CNFGNN.
• GRU (local): for each node we train a GRU model with only the local data on it.
• GRU + FedAvg: a GRU model trained with the Federated Averaging algorithm [66]. We select 1 as
the number of local epochs.
• GRU + FMTL: for each node we train a GRU model using the federated multi-task learning (FMTL)
with cluster regularization [92] given by the adjacency matrix. More specifically, the cluster regularization (without the L2-norm regularization term) takes the following form:
R(W, Ω) = λ tr(WΩW^T).    (5.4)
Given the constructed adjacency matrix A, Ω = (1/|V|)(D − A) = (1/|V|)L, where D is the degree matrix and L is the Laplacian matrix. Equation 5.4 can be reformulated as:

R(W, Ω) = λ tr(WΩW^T)
        = (λ/|V|) tr(W L W^T)
        = (λ/|V|) tr(Σ_{i∈V} (w_i Σ_{j≠i} a_{ij} w_i^T − Σ_{j≠i} w_i a_{ij} w_j^T))
        = λ_1 (Σ_{i∈V} Σ_{j≠i} α_{i,j} ⟨w_i, w_i − w_j⟩).    (5.5)
We implement the cluster regularization via sharing model weights between each pair of nodes
connected by an edge and select λ1 = 0.1. For each baseline, we have 2 variants of the GRU model to
show the effect of on-device model complexity: one with 63K parameters and the other with 727K
parameters. For CNFGNN, the encoder-decoder model on each node has 64K parameters and the GN
model has 1M parameters.
• CNFGNN We use a GRU-based encoder-decoder model as the model on nodes, which has 1 GRU
layer and hidden dimension 64. We use a 2-layer Graph Network (GN) with residual connections as
the Graph Neural Network model on the server side. We use the same network architecture for the
edge/node/global update function in each GN layer: a multi-layer perceptron (MLP) with 3 hidden
layers, whose sizes are [256, 256, 128] respectively. We choose Rc = 1, Rs = 20 for experiments on
PEMS-BAY, and Rc = 1, Rs = 1 for METR-LA.
Calculation of Communication Cost We denote R as the number of communication rounds for one
model to reach the lowest validation error in the training stage.
GRU + FMTL Using Equation 5.5, in each communication round, each pair of nodes exchange their
model weights, thus the total communicated data amount is calculated as:
R × #nonself directed edges × size of node model weights. (5.6)
We list corresponding parameters in Table 5.3.
CNFGNN (AT + FedAvg) In each communication round, the central server fetches and sends back model
weights to each node for Federated Averaging, and transmits hidden vectors and gradients for Split Learning.
The total communicated data amount is calculated as:
R × (#nodes × size of node model weights × 2 + (1 + 2 × server rounds + 1) × #nodes × hidden state size).    (5.7)
We list corresponding parameters in Table 5.4.
CNFGNN (SL) In each communication round, each node sends and fetches hidden vectors and gradients twice (once for the encoder, once for the decoder), and the total communicated data amount is:

R × 2 × 2 × #nodes × hidden state size.    (5.8)
We list corresponding parameters in Table 5.5.
CNFGNN (SL + FedAvg) Compared to CNFGNN (SL), this method has extra communication cost for FedAvg in each round, thus the total communicated data amount is:

R × (#nodes × size of node model weights × 2 + 2 × 2 × #nodes × hidden state size).    (5.9)
We list corresponding parameters in Table 5.6.
CNFGNN (AT, w/o FedAvg) Compared to CNFGNN (AT + FedAvg), there is no communication cost for the FedAvg part, thus the total communicated data amount is:

R × (1 + 2 × server rounds + 1) × #nodes × hidden state size.    (5.10)
We list corresponding parameters in Table 5.7.
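For reference, these formulas translate directly into code; the helpers below are a minimal transcription of Equations 5.6-5.10 (the inputs, e.g. R, the node model weight size, or the hidden state size, are the quantities listed in Tables 5.3-5.7, and the resulting totals depend on those measured sizes):

```python
def cost_fmtl(rounds, nonself_edges, node_weight_gb):
    # Equation 5.6: every edge exchanges node model weights each round.
    return rounds * nonself_edges * node_weight_gb

def cost_at_fedavg(rounds, nodes, node_weight_gb, hidden_gb, server_rounds):
    # Equation 5.7: FedAvg weight exchange plus alternating-training GN exchange.
    return rounds * (nodes * node_weight_gb * 2
                     + (1 + 2 * server_rounds + 1) * nodes * hidden_gb)

def cost_sl(rounds, nodes, hidden_gb):
    # Equation 5.8: hidden vectors and gradients, twice per node per round.
    return rounds * 2 * 2 * nodes * hidden_gb

def cost_sl_fedavg(rounds, nodes, node_weight_gb, hidden_gb):
    # Equation 5.9: Split Learning exchange plus FedAvg weight exchange.
    return rounds * (nodes * node_weight_gb * 2 + 2 * 2 * nodes * hidden_gb)

def cost_at_no_fedavg(rounds, nodes, hidden_gb, server_rounds):
    # Equation 5.10: alternating-training GN exchange only.
    return rounds * (1 + 2 * server_rounds + 1) * nodes * hidden_gb
```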
Table 5.3: Parameters used for calculating the communication cost of GRU + FMTL.

                                        GRU (63K) + FMTL    GRU (727K) + FMTL
Node Model Weights Size (GB)            2.347E-4            2.708E-3
PEMS-BAY   #Nonself Directed Edges      2369                2369
           R                            104                 56
           Train Comm Cost (GB)         57.823              359.292
METR-LA    #Nonself Directed Edges      1515                1515
           R                            279                 176
           Train Comm Cost (GB)         99.201              722.137
Discussion Table 5.8 shows the comparison of forecasting performance and Table 5.9 shows the comparison of computation cost on device and communication cost of CNFGNN and baselines. We make the
following observations. Firstly, when we compare the best forecasting performance of each baseline over
the 2 GRU variants, GRU trained with FedAvg performs the worst in terms of forecasting performance
compared to GRU trained with centralized data and GRU trained with local data (4.432 vs 4.010/4.124 on
Table 5.4: Parameters used for calculating the communication cost of CNFGNN (AT + FedAvg).

                                PEMS-BAY    METR-LA
Node Model Weights Size (GB)    2.384E-4    2.384E-4
#Nodes                          325         207
Hidden State Size (GB)          2.173E-3    1.429E-3
Server Round                    20          1
R                               2           46
Train Comm Cost (GB)            237.654     222.246
Table 5.5: Parameters used for calculating the communication cost of CNFGNN (SL).

                          PEMS-BAY    METR-LA
#Nodes                    325         207
Hidden State Size (GB)    2.173E-3    1.429E-3
R                         31          65
Train Comm Cost (GB)      350.366     307.627
Table 5.6: Parameters used for calculating the communication cost of CNFGNN (SL + FedAvg).

                                PEMS-BAY    METR-LA
Node Model Weights Size (GB)    2.384E-4    2.384E-4
#Nodes                          325         207
Hidden State Size (GB)          2.173E-3    1.429E-3
R                               7           71
Train Comm Cost (GB)            80.200      343.031
Table 5.7: Parameters used for calculating the communication cost of CNFGNN (AT, w/o FedAvg).

                          PEMS-BAY    METR-LA
#Nodes                    325         207
Hidden State Size (GB)    2.173E-3    1.429E-3
Server Round              20          1
R                         44          49
Train Comm Cost (GB)      5221.576    2434.985
Table 5.8: Comparison of performance on the traffic flow forecasting task. We use the Root Mean Squared Error (RMSE) to evaluate the forecasting performance.

Method                              PEMS-BAY    METR-LA
GRU (centralized, 63K)              4.124       11.730
GRU (centralized, 727K)             4.128       11.787
GRU + GN (centralized, 64K + 1M)    3.816       11.471
GRU (local, 63K)                    4.010       11.801
GRU (local, 727K)                   4.152       12.224
GRU (63K) + FedAvg                  4.512       12.132
GRU (727K) + FedAvg                 4.432       12.058
GRU (63K) + FMTL                    3.961       11.548
GRU (727K) + FMTL                   3.955       11.570
CNFGNN (64K + 1M)                   3.822       11.487
Table 5.9: Comparison of the computation cost on edge devices and the communication cost. We use the number of floating point operations (FLOPs) to measure the computational cost of models on edge devices. We also show the total size of data/parameters transmitted in the training stage (Train Comm Cost) until the model reaches its lowest validation error.

                     Comp Cost On       PEMS-BAY                         METR-LA
Method               Device (GFLOPs)    RMSE     Train Comm Cost (GB)    RMSE      Train Comm Cost (GB)
GRU (63K) + FMTL     0.159              3.961    57.823                  11.548    99.201
GRU (727K) + FMTL    1.821              3.955    359.292                 11.570    722.137
CNFGNN (64K + 1M)    0.162              3.822    237.654                 11.487    222.246
PEMS-BAY and 12.058 vs 11.730/11.801 on METR-LA), showing that the data distributions on different
nodes are highly heterogeneous, and training one single model ignoring the heterogeneity is suboptimal.
Secondly, both the GRU+FMTL baseline and CNFGNN consider the spatial relations among nodes and
show better forecasting performance than baselines without relation information. This shows that the
modeling of spatial dependencies is critical for the forecasting task.
Lastly, CNFGNN achieves the lowest forecasting error on both datasets. The baseline that increases the complexity of on-device models (GRU (727K) + FMTL) gains slight or even no improvement, at the cost of higher computation cost on edge devices and larger communication cost. In contrast, due to its effective modeling of spatial dependencies in data, CNFGNN not only achieves the largest improvement in forecasting performance, but also keeps the computation cost on devices almost unchanged and maintains a modest communication cost compared to the baselines that increase the model complexity on devices.
5.3.2 Inductive Learning on Unseen Nodes
Table 5.10: Inductive learning performance measured with root mean squared error (RMSE).

                      PEMS-BAY                                  METR-LA
Method                5%      25%     50%     75%     90%      5%        25%      50%      75%      90%
GRU (63K) + FedAvg    5.087   4.863   4.847   4.859   4.866    12.128    11.993   12.104   12.014   12.016
CNFGNN (64K + 1M)     5.869   4.541   4.598   4.197   3.942    13.931    12.013   11.815   11.676   11.629
Set-up Another advantage of CNFGNN is that it can conduct inductive learning and generalize to larger
graphs with nodes unobserved during the training stage. We evaluate the performance of CNFGNN under
the following inductive learning setting: for each dataset, we first sort all sensors based on longitudes,
then use the subgraph on the first η% of sensors to train the model and evaluate the trained model on the
entire graph. For each dataset we select η% = 25%, 50%, 75%. Among all baselines following the cross-node federated learning constraint, GRU (local) and GRU + FMTL require training new models on unseen nodes, and only GRU + FedAvg is applicable to the inductive learning setting.
Figure 5.3: Visualization of subgraphs visible in training under different ratios. (a) PEMS-BAY. (b) METR-LA.
Discussion Table 5.10 shows the inductive learning performance of CNFGNN and the GRU + FedAvg baseline on both datasets. We observe that under most settings, CNFGNN outperforms the GRU + FedAvg baseline (except on the METR-LA dataset with 25% of nodes observed in training, where both models perform similarly), showing that CNFGNN has a stronger generalization ability. Table 5.10 further includes results using 90% and 5% of the nodes on both datasets.
increasing, the prediction error of CNFGNN decreases drastically. However, the increase of the portion of
visible nodes has negligible contribution to the performance of GRU + FedAvg after the portion surpasses
25%. Since increasing the ratio of seen nodes in training introduces more complex relationships among
nodes to the training data, the difference of performance illustrates that CNFGNN has a stronger capability
of capturing complex spatial relationships. (2) When the ratio of visible nodes in training is extremely
low (5%), there is not enough spatial relationship information in the training data to train the GN module
in CNFGNN, and the performance of CNFGNN may not be ideal. We visualize the subgraphs visible in
training under different ratios in Figure 5.3. However, as long as the training data covers a moderate portion
of the spatial information of the whole graph, CNFGNN can still leverage the learned spatial connections
Figure 5.4: Validation loss during the training stage of different training strategies (Centralized, SL, SL + FedAvg, AT w/o FedAvg, AT + FedAvg). (a) PEMS-BAY. (b) METR-LA.
among nodes effectively and outperforms GRU+FedAvg. We empirically show that the necessary ratio can
vary for different datasets (25% for PEMS-BAY and 50% for METR-LA).
5.3.3 Ablation Study: Effect of Alternating Training and FedAvg on Node-Level and
Spatial Models
Table 5.11: Comparison of test error (RMSE) and the communication cost during training of different training strategies of CNFGNN.

                   PEMS-BAY                         METR-LA
Method             RMSE     Train Comm Cost (GB)    RMSE      Train Comm Cost (GB)
Centralized        3.816    -                       11.471    -
SL                 3.914    350.366                 12.186    307.627
SL + FedAvg        4.383    80.200                  11.631    343.031
AT, w/o FedAvg     4.003    5221.576                11.912    2434.985
AT + FedAvg        3.822    237.654                 11.487    222.246
Baselines We compare the effect of different training strategies of CNFGNN: (1) Centralized: CNFGNN trained with centralized data where all nodes share one single encoder-decoder. (2) Split Learning
(SL): CNFGNN trained with split learning [91], where models on nodes and the model on the server are
jointly trained by exchanging hidden vectors and gradients. (3) Split Learning + FedAvg (SL + FedAvg):
A variant of SL that synchronizes the weights of encoder-decoder modules periodically with FedAvg. (4)
Alternating training without Federated Averaging of models on nodes (AT, w/o FedAvg). (5) Alternating
training with Federated Averaging on nodes described in Section 5.2.2.3 (AT + FedAvg).
Discussion Figure 5.4 shows the validation loss during training of different training strategies on the PEMS-BAY and METR-LA datasets, and Table 5.11 shows their prediction performance and communication cost in training. We notice that (1) SL suffers from suboptimal prediction performance and high communication costs on both datasets; SL + FedAvg does not show consistent results on both datasets and its performance is always inferior to AT + FedAvg. AT + FedAvg consistently outperforms the other baselines on both datasets, including its variant without FedAvg. (2) AT + FedAvg has the lowest communication cost on METR-LA and the 2nd lowest communication cost on PEMS-BAY, on which the baseline with the lowest communication cost (SL + FedAvg) has a much higher prediction error (4.383 vs 3.822). Both illustrate that our proposed training strategy, AT + FedAvg, achieves the best prediction performance as well as low communication cost compared to other baseline strategies.
5.3.4 Ablation Study: Effect of Client Rounds and Server Rounds
Set-up We further investigate the effect of different compositions of the number of client rounds (R_c) in Algorithm 2 and the number of server rounds (R_s) in Algorithm 1. To this end, we vary both R_c and R_s over [1, 10, 20].
Discussion Figure 5.5 shows the forecasting performance (measured with RMSE) and the total communication cost in the training of CNFGNN under all compositions of (Rc, Rs) on the METR-LA dataset. We
observe that: (1) Models with lower Rc/Rs ratios (Rc/Rs < 0.5) tend to have lower forecasting errors
while models with higher Rc/Rs ratios (Rc/Rs > 2) have lower communication cost in training. This is
because the lower ratio of Rc/Rs encourages more frequent exchange of node information at the expense
of higher communication cost, while the higher ratio of Rc/Rs acts in the opposite way. (2) Models with
Figure 5.5: Effect of client rounds and server rounds (Rc, Rs) on forecasting performance and communication cost.
similar Rc/Rs ratios have similar communication costs, while those with lower Rc values perform better,
corroborating our observation in (1) that frequent node information exchange improves the forecasting
performance.
5.4 Conclusion
We propose Cross-Node Federated Graph Neural Network (CNFGNN), which bridges the gap between
modeling complex spatio-temporal data and decentralized data processing by enabling the use of graph
neural networks (GNNs) in the federated learning setting. We accomplish this by decoupling the learning
of local temporal models and the server-side spatial model using alternating optimization of spatial and
temporal modules based on split learning and federated averaging. Our experimental results on traffic flow
prediction on two real-world datasets show superior performance as compared to competing techniques.
Our future work includes applying existing GNN models with sampling strategies and integrating them into
CNFGNN for large-scale graphs, extending CNFGNN to a fully decentralized framework, and incorporating
existing privacy-preserving methods for graph learning into CNFGNN to enhance federated learning of spatio-temporal dynamics.
Chapter 6
Sample-Level Prototypical Federated Learning
6.1 Introduction
Federated learning (FL) is a family of approaches enabling multiple clients to train machine learning models
collaboratively without exchanging data. FL has been emerging as a new machine learning paradigm due
to its advantages in protecting data privacy and satisfying data regulations [66, 35], and thus has drawn
great interest in cross-silo scenarios, where each client is an organization serving multiple users and thus
owns data from heterogeneous distributions. For example, in healthcare applications, each client (hospital)
gathers data from patients with different backgrounds and conditions. What exacerbates the challenge is
that real-world data samples usually lack explicit indicators specifying the distributions they are collected
from, or which samples are from the same distribution. Such scenarios call for fine-grained personalized
modeling for each data sample rather than personalization at the client level.
As a consequence of the heterogeneity of data generation processes across clients and the non-exchangeability of data, non-identically and non-independently distributed (non-IID) data has been a fundamental statistical challenge in FL that leads to slow model convergence and performance drops [111, 52, 53]. According to [35], the data (feature X and label y) distribution on the i-th client can be rewritten
as Pi(X, y) = Pi(y|X)Pi(X) for a more precise characterization of differences in distributions, where
the variance of Pi(X) across clients is known as feature-distribution skew and the variance of Pi(y|X) as
concept shift. While both types of variances across clients have been abundantly discussed, the aforementioned internal data heterogeneity has hardly drawn attention in the FL context. Nevertheless, the latter is
ubiquitous in FL scenarios.
Existing works have developed various strategies to alleviate the data distribution variances across
clients. FedProx [52] and SCAFFOLD [36] improve the stability of federated training, but they still hold
the assumption that concept shift does not exist and train a unique global model for all clients. Instead,
personalized FL methods address the existence of concept shift and train a personalized model for each
client. Still, it is necessary for personalized FL methods to apply constraints on the heterogeneity of
Pi(y|X) to achieve better performance than training a model for each client with its local data only,
according to the impossibility result for Semi-Supervised Learning [7, 65]. Among them, relation-based
methods (pFedMe [94], perFedAvg [21], Ditto [50], FedMTL [92]) constrain differences of distributions
via regularizations between personalized and global model weights or among weights of client models.
Local-global methods (LG-FedAvg [57], FedRep [18]) restrict the distribution differences by partially sharing
model weights in personalized models. Cluster-based FL [84, 23] methods assign a cluster for each client
and train one shared model for all clients in each cluster. FedEM [65] further softens the cluster assignment
and models each Pi(y|x) as a mixture of underlying distributions, and the personalized model shared by
all samples on each client is a linear combination of components for each distribution. However, in all
these methods, samples on the same client still share one unified model, and the internal data heterogeneity
inside each client is neglected. Recent works (FedPCL [95], FedNH [20]) regard the labels in classification
tasks as the data domain indicators and learn representations for each label respectively. However, these
methods have difficulties generalizing to tasks with continuous labels, such as regression.
We propose Sample-Level Prototypical Federated Learning (SL-PFL), which considers the multi-domain
nature across data samples and provides a personalized model at the sample level. To address the challenge
that explicit data domain indicators are inaccessible or even not well-defined in real-world scenarios, we
develop a federated prototypical contrastive learning method that learns clustered representations as well as
soft domain assignments in a semi-supervised way. The sample-level personalized prediction model is then
realized as the ensemble of domain-specific models. Our contributions can be summarized as follows: (1)
We provide a fine-grained distribution factorization modeling the internal multi-domain property for clients
in personalized FL. (2) Based on the factorization, we develop a federated semi-supervised prototypical
contrastive learning method named Sample-Level Prototypical Federated Learning(SL-PFL), which provides
a fine-grained personalized model for each data sample without requiring explicit domain indicators in
training data. (3) With sample-level personalized models, SL-PFL achieves better performance than existing
FL methods with either a uniform global model or client-level personalized models on various regression and
classification tasks from real-world applications, where each client contains private data from heterogeneous
domains.
6.2 Proposed Method: Sample-Level Prototypical Federated Learning (SL-PFL)
6.2.1 Problem Formulation
In the settings of federated learning, C clients collaboratively train models while keeping each client’s data
local. We denote the visible dataset of the c-th client as Dc = {(Xc,i, yc,i) | i = 1, . . . , nc}. It contains nc
samples from the local data distribution Pc. We can then formalize the target of FL as the optimal model
parameters that minimize the average true risk of all clients:
{Θ*_1, . . . , Θ*_C} = argmin_{Θ_1,...,Θ_C} (1/C) Σ_{c=1}^{C} E_{(X,y)∼P_c} ℓ_c(X, y; Θ_c),    (6.1)
where Θ_c is the model parameter on the c-th client. With the common Empirical Risk Minimization (ERM) technique, we usually solve the following optimization problem within the training process:

{Θ*_1, . . . , Θ*_C} = argmin_{Θ_1,...,Θ_C} (1/C) Σ_{c=1}^{C} (1/n_c) Σ_{i=1}^{n_c} ℓ_c(X_{c,i}, y_{c,i}; Θ_c),    (6.2)
where ℓc is the loss function for the task of the c-th client.
For common tasks, such as classification and regression, the loss function ℓ(X, y; Θ) can be viewed as
the negative log-likelihood − log P(X, y; Θ) from the probabilistic perspective. Therefore, Equation 6.2
can be rewritten as:
{Θ*_1, . . . , Θ*_C} = argmin_{Θ_1,...,Θ_C} − (1/C) Σ_{c=1}^{C} (1/n_c) Σ_{i=1}^{n_c} log P_c(X_{c,i}, y_{c,i}; Θ_c),    (6.3)
where Pc(Xc,i, yc,i; Θc) is the joint probability of (Xc,i, yc,i) on the c-th client parameterized by Θc. In
most cases, we have the same assumption on the types of joint distributions for all clients (for example,
categorical distribution for classification and Gaussian for regression). Thus, we will simplify Pc as P in
the following text.
6.2.2 Sample-Level Factorization of Data Distribution
To better illustrate the data distribution, we rewrite P(Xc,i, yc,i; Θc) as P(yc,i|Xc,i; Θc)P(Xc,i; Θc). As
shown in [65], when there is no assumption on relations between local data distributions, the FL problem
can be reduced to C parallel semi-supervised learning (SSL) problems. Each of the problems can learn by
using (1) the labeled local dataset and (2) unlabeled datasets from other clients. The latter is unlabeled
since labels on other clients have no relevance to the current client's data without any assumption. Since in FL one client can only use data from other clients indirectly via model communication and aggregation, the FL problem is at least as hard as SSL. Given the impossibility result for SSL [7] that the worst sample
complexity of SSL can only improve over supervised learning (corresponding to training with local data
only in FL) by a constant factor, we have that FL methods with no assumption on local data distributions
can only improve sample complexity over local training by at most a constant factor. Therefore, reasonable
assumptions on local data distributions are required for effective FL algorithms.
In this work, considering the multi-domain nature of data, we propose the following assumption on
local data distributions that each data sample (Xc,i, yc,i) on any client c is associated with an unobservable
hidden factor dc,i representing its data domain. We formalize the assumption as:
Assumption 1. All data samples are from K domains. Each sample on any client c is generated from the joint data distribution parameterized by Θ_c = Θ = {Φ, K, C}:

P(X_{c,i}, y_{c,i}; Θ_c) = Σ_{d_{c,i}=1}^{K} P(y_{c,i} | X_{c,i}, d_{c,i}; Φ) P(X_{c,i} | d_{c,i}; K) C(d_{c,i})
                        = Σ_{d_{c,i}=1}^{K} P(y_{c,i} | X_{c,i}; φ_{d_{c,i}}) P(X_{c,i} | d_{c,i}; K) C(d_{c,i}),    (6.4)

where Φ ∈ R^{K×P_1} and K ∈ R^{K×P_2} are matrices of parameters shared across clients, φ_{d_{c,i}} = Φ[d_{c,i}, :] is the P_1-dim parameter vector of the prediction model for the d_{c,i}-th domain, and C is the prior categorical distribution of data domains. For the i-th sample, the data generation process is:

d_{c,i} ∼ C,    X_{c,i} ∼ P(X_{c,i} | d_{c,i}; K),    y_{c,i} ∼ P(y_{c,i} | X_{c,i}; φ_{d_{c,i}}).    (6.5)
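To make the generative assumption concrete, here is a minimal NumPy simulation of the process in Equation 6.5 for a toy regression setting; the Gaussian form of P(X | d; K) and the linear form of P(y | X; φ_d) are illustrative choices rather than part of the assumption:

```python
import numpy as np

rng = np.random.default_rng(0)
K, P1, dim_x = 3, 4, 4                       # domains, predictor size, feature dim

C = np.array([0.5, 0.3, 0.2])                # prior over domains
mu = rng.normal(size=(K, dim_x))             # domain-specific feature means (stand-in for K)
Phi = rng.normal(size=(K, P1))               # per-domain linear predictors (rows phi_d)

def sample(n):
    d = rng.choice(K, size=n, p=C)                             # d_{c,i} ~ C
    X = mu[d] + 0.1 * rng.normal(size=(n, dim_x))              # X_{c,i} ~ P(X | d; K)
    y = (X * Phi[d]).sum(axis=1) + 0.05 * rng.normal(size=n)   # y_{c,i} ~ P(y | X; phi_d)
    return X, y, d

X, y, d = sample(100)                        # samples carry an unobserved domain label d
```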
6.2.3 Methodology
In Figure 6.1, we provide an overview of SL-PFL. In the following text, we will introduce the local training
and the federated aggregation steps of SL-PFL in detail. We give the pseudocode of SL-PFL in Algorithm 3.
Algorithm 3 Sample-Level Prototypical Federated Learning (SL-PFL)

Input: client datasets D_1:C, number of global domains K, number of client domains {k_1, k_2, . . . , k_C}, number of communication rounds R
Output: θ, θ_mo, Φ, {C_1:C}
for iteration r = 1, . . . , R do
    Server broadcasts θ^(r−1), θ_mo^(r−1), Φ^(r−1) to sampled C_r ≤ C clients.
    for client c in parallel over C_r clients do
        V'_c = f_{θ_mo,c^(r−1)}(X_c)
        {(κ_c,i, m_c,i) | i ∈ [k_c]} = k_c-means(V'_c).
        Client sends back {(κ_c,i, m_c,i) | i ∈ [k_c]}.
    end for
    {κ_i | i ∈ [K]} = K-means({κ_c,i | c ∈ C_r, i ∈ [k_c]})
    Estimate {ρ_i | i ∈ [K]} with Equation 6.17.
    for client c in parallel over C_r clients do
        for x ∈ X_c do
            v = f_{θ_c^(r−1)}(x),  v' = f_{θ_mo,c^(r−1)}(x)
            Update Φ_c^r, θ_c^r, θ_c,mo^r with Equation 6.15 and Equation 6.16.
        end for
        Update C_c according to Equation 6.13.
        Client sends back θ_c^r, θ_mo,c^r, Φ_c^r to the server.
    end for
    θ^r = (1/C_r) Σ_{c=1}^{C_r} θ_c^r;   θ_mo^r = (1/C_r) Σ_{c=1}^{C_r} θ_c,mo^r;   Φ^r = (1/C_r) Σ_{c=1}^{C_r} Φ_c^r.
end for
θ = θ^R, θ_mo = θ_mo^R, Φ = Φ^R.
return θ, θ_mo, Φ, {C_1:C}
Figure 6.1: Overview of SL-PFL. (1) Each client runs kc-means clustering of embeddings from the momentum
encoder and produces local prototypes. (2) The server collects local prototypes of clients and runs K-means
clustering to get global prototypes. (3) The server sends global prototypes back to each client. (4) Each
client embeds input with the encoder and calculates the prototypical (Equation 6.10) and the InfoNCE
(Equation 6.14) losses. (5) Each client conducts domain-specific prediction and ensembles the results based
on the estimated posterior domain distributions. Task-specific supervised loss is then calculated with
the ensembled prediction and the ground truth. With the supervised, prototypical, and InfoNCE losses
(Equation 6.15), each client optimizes local parameters via gradient descent. (6) The server aggregates local
models via FedAvg. (7) The server sends the aggregated models back to clients.
Local Training Combining Equation 6.3 and Equation 6.4, we have the equivalent optimization problem:

Θ* = argmin_Θ (1/C) Σ_{c=1}^{C} (1/n_c) Σ_{i=1}^{n_c} L_{c,i},  where
L_{c,i} = − log Σ_{d_{c,i}=1}^{K} P(X_{c,i}, y_{c,i}, d_{c,i}; Θ)
        = − log Σ_{d_{c,i}=1}^{K} P(y_{c,i} | X_{c,i}; φ_{d_{c,i}}) P(X_{c,i} | d_{c,i}; K) C(d_{c,i}).    (6.6)
As Equation 6.6 is hard to optimize directly, we apply the Expectation-Maximization (EM) algorithm and minimize the following Q-function for each client and each data sample instead:

Q(Θ; Θ_old) = (1/C) Σ_{c=1}^{C} (1/n_c) Σ_{i=1}^{n_c} Q_{c,i}(Θ; Θ_old),
Q_{c,i}(Θ; Θ_old) = −E_{d_{c,i} ∼ P(d_{c,i} | y_{c,i}, X_{c,i}; Θ_old)} [log P(y_{c,i}, X_{c,i}, d_{c,i}; Θ)]
                  = −E_{d_{c,i} ∼ P(d_{c,i} | y_{c,i}, X_{c,i}; Θ_old)} [ log P(y_{c,i} | X_{c,i}; φ_{d_{c,i}}) + log P(X_{c,i} | d_{c,i}; K) + log C(d_{c,i}) ],    (6.7)

where the three terms inside the expectation correspond, in order, to (1) the supervised learning loss, (2) the prototypical learning loss, and (3) the prior loss.
Equation 6.7 describes the optimization target function for each client during the local training stage of
FL and contains three loss terms: (1) supervised learning loss that encourages the model to minimize the
prediction loss on training data; (2) prototypical learning loss that favors clustered representations; (3) prior
loss that modifies the estimation of prior domain distributions based on the observed data. In local training,
two encoder networks are used to extract the embedding of each data sample: an encoder parameterized by θ, producing v_i = f_θ(X_i), and a momentum encoder parameterized by θ_mo, which is updated with a lower learning rate during training and produces more stable data embeddings v'_i = f_{θ_mo}(X_i), as shown in [27, 48]. The local training procedure is composed of two steps, described as follows (unless otherwise stated, we omit the client subscript c since local training only involves one client most of the time):
E-step In this step, we run k_c-means clustering on the features from the momentum encoder {v'_i | i ∈ [N]} to obtain k_c local clusters, where k_c is the specified number of local prototypes and may vary with clients. Through the federated aggregation process described later, each client gets the globally aggregated prototypes of K clusters, K = {κ_k | k ∈ [K]}, from the server. Then, we estimate the posterior distribution of the domain factor d_i conditioned on the data as:

P(d_i | X_i, y_i; Θ) ∝ P(y_i | X_i; φ_{d_i}) P(X_i | d_i; K) C(d_i).    (6.8)
Since P(y_i | X_i; φ_{d_i}) and C(d_i) can be directly derived from the prediction losses and the parameter C, we mainly discuss the term P(X_i | d_i; K). Following [48], we assume that the distribution of data sample embeddings is an isotropic Gaussian around the prototypes K:

P(X_i | d_i; K) = exp(−(v_i − κ_{d_i})^2 / (2σ_{d_i}^2)) / Σ_{k=1}^{K} exp(−(v_i − κ_k)^2 / (2σ_k^2)),    (6.9)

where v_i = f_θ(X_i). We further require that all {v_i}, {v'_i}, {κ_k} are L2-normalized; then Equation 6.9 can be reformulated as:

P(X_i | d_i; K) = C · exp(v_i · κ_{d_i} / ρ_{d_i}) / Σ_{k=1}^{K} exp(v_i · κ_k / ρ_k),    (6.10)
where C is a constant and ρ_k ∝ σ_k^2 is a variable describing the concentration level of the feature distribution around the k-th cluster:

ρ_k = (Σ_{i=1}^{m_k} ‖v'_i − κ_k‖_2) / (m_k log(m_k + α)),    (6.11)

where α smooths the estimation and avoids excessively large values for clusters with few samples. It is worth noting that ρ_k should consider all samples in the k-th cluster across clients, and directly calculating Equation 6.11 is impractical under the FL setting. We will introduce the solution in the following federated aggregation part.
While [48] estimates the posterior distribution as a one-hot categorical distribution with a probability
of 1 on the category given by K-means clustering, we find that the above soft form of estimation leads to
slightly better results, as we demonstrate in Section 6.3.5.
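A small sketch of this soft assignment, corresponding to Equations 6.10 and 6.11, is given below; the embeddings, prototypes, and smoothing constant α are placeholders:

```python
import numpy as np

def concentration(embeddings, centroid, alpha=10.0):
    """rho_k from Equation 6.11 for one cluster's (momentum) embeddings."""
    m = len(embeddings)
    return np.linalg.norm(embeddings - centroid, axis=1).sum() / (m * np.log(m + alpha))

def soft_assignment(v, prototypes, rho):
    """P(X|d; K) for one L2-normalized embedding v (Equation 6.10, up to a constant)."""
    logits = prototypes @ v / rho                 # v . kappa_k / rho_k
    logits -= logits.max()                        # numerical stability
    p = np.exp(logits)
    return p / p.sum()

# Toy example: 5 prototypes in a 16-dim embedding space.
rng = np.random.default_rng(0)
prototypes = rng.normal(size=(5, 16))
prototypes /= np.linalg.norm(prototypes, axis=1, keepdims=True)
v = rng.normal(size=16)
v /= np.linalg.norm(v)
rho = np.full(5, 0.1)
print(soft_assignment(v, prototypes, rho))
```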
M-step We rewrite Equation 6.7 with the actual parameters to optimize (\Phi, \theta, \theta_{mo}, C):

Q_i(\Phi, \theta, \theta_{mo}, C; \Phi_{old}, \theta_{old}, \theta_{mo,old}, C_{old}) = -\mathbb{E}_{d_i \sim P(d_i \mid X_i, y_i; \Phi_{old}, \mathcal{K}(\theta_{mo,old}))} \left[ \log P(y_i \mid X_i; \phi_{d_i}) + \log P(X_i \mid d_i; \mathcal{K}(\theta_{mo})) + \log C(d_i) \right].   (6.12)
Combining Equations 6.8, 6.10 and 6.12, we derive the objective function of each data sample to minimize in local training:

Q_i(\Theta; \Theta_{old}) =
\underbrace{- \sum_{d_i=1}^{K} \frac{\exp(s(v_i \cdot \kappa_{d_i})/\rho_{d_i})}{\sum_{j=1}^{K} \exp(s(v_i \cdot \kappa_j)/\rho_j)} \cdot \log P(y_i \mid X_i; \phi_{d_i})}_{\mathcal{L}_{i,\text{supervised}}}
\underbrace{- \sum_{d_i=1}^{K} \frac{\exp(s(v_i \cdot \kappa_{d_i})/\rho_{d_i})}{\sum_{j=1}^{K} \exp(s(v_i \cdot \kappa_j)/\rho_j)} \cdot \log \frac{\exp(v_i \cdot \kappa_{d_i}/\rho_{d_i})}{\sum_{k=1}^{K} \exp(v_i \cdot \kappa_k/\rho_k)}}_{\mathcal{L}_{i,\text{prototypical}}}
+ D_{KL}\big(P(d_i \mid X_i, y_i; \Phi_{old}, \mathcal{K}(\theta_{mo,old})) \,\big\Vert\, C(d_i)\big)
- H\big(P(d_i \mid X_i, y_i; \Phi_{old}, \mathcal{K}(\theta_{mo,old}))\big),   (6.13)
where D_{KL} and H stand for the Kullback–Leibler (KL) divergence and the entropy, respectively. Notice that the entropy of the posterior distribution is constant in the M-step, so we only need to optimize the first three terms. For C, we only need to find the C that minimizes D_{KL}, which is exactly the posterior distribution from the E-step.
We optimize Φ, θ by minimizing the loss function composed of the first two terms from Equation 6.13.
In addition, we include the sample-wise InfoNCE loss [74] in order to better learn sample representations:
\mathcal{L}_{i,\text{InfoNCE}} = -\log \frac{\exp(v_i \cdot v'_i / \tau)}{\sum_{j=1}^{r} \exp(v_i \cdot v'_j / \tau)},   (6.14)
where r is the number of negative samples (other data samples) and τ controls the concentration of the
feature distribution around each instance. As a result, the overall local training loss is formalized as:
\mathcal{L} = \sum_{i=1}^{n} \left( \mathcal{L}_{i,\text{supervised}} + \mathcal{L}_{i,\text{prototypical}} + \mathcal{L}_{i,\text{InfoNCE}} \right).   (6.15)
\Phi, \theta are optimized via stochastic gradient descent or similar methods, while \theta_{mo} is updated with momentum for stability:

\Phi' \leftarrow \Phi - \alpha \nabla_{\Phi} \mathcal{L}, \quad \theta' \leftarrow \theta - \alpha \nabla_{\theta} \mathcal{L}, \quad \theta'_{mo} \leftarrow \beta \theta_{mo} + (1 - \beta) \theta'.   (6.16)
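The M-step loss and the momentum update can be sketched as follows in PyTorch; the single-sample interface, the variable names, and treating the E-step posterior weights gamma as fixed are simplifications for illustration rather than the exact implementation.

```python
# A PyTorch-style sketch of the per-sample M-step loss (Equations 6.13-6.15)
# and the momentum-encoder update (Equation 6.16). Names are hypothetical.
import torch
import torch.nn.functional as F

def local_loss(v, v_mo, v_mo_negatives, prototypes, rho, gamma,
               pred_log_likelihoods, tau=0.07):
    # (1) supervised term: posterior-weighted log-likelihoods of the K
    #     domain-specific predictors, log P(y_i | X_i; phi_{d_i}).
    l_sup = -(gamma * pred_log_likelihoods).sum()
    # (2) prototypical term: posterior-weighted log prototype softmax.
    log_proto = F.log_softmax((prototypes @ v) / rho, dim=0)
    l_proto = -(gamma * log_proto).sum()
    # (3) sample-wise InfoNCE term (Equation 6.14): the momentum embedding of
    #     the same sample is the positive; other samples act as negatives.
    keys = torch.cat([v_mo.unsqueeze(0), v_mo_negatives], dim=0)  # (1+r, d)
    l_nce = -F.log_softmax(keys @ v / tau, dim=0)[0]
    return l_sup + l_proto + l_nce

@torch.no_grad()
def momentum_update(encoder, momentum_encoder, beta=0.99):
    # theta'_mo <- beta * theta_mo + (1 - beta) * theta'  (Equation 6.16)
    for p, p_mo in zip(encoder.parameters(), momentum_encoder.parameters()):
        p_mo.mul_(beta).add_(p, alpha=1 - beta)
```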
Federated Aggregation The above local training procedure performs the clustering on local data samples.
While parameters Φ, θ, θmo can be aggregated with regular FL methods such as FedAvg, and C is kept
locally on each client, prototypes generated by local clustering must be aggregated properly to form a set
of global prototypes in FL. Here, we develop the aggregation procedure of local prototypes.
We first conduct local K-means on data embeddings of clients participating in the training round in
parallel. Each client then uploads the clustering centroids and numbers of samples of all clusters to the
central server. We then conduct K-means on centroids from all clients and use the result as the final
prototypes. We also modify Equation 6.11 for the global clustering on centroids:
\rho_k = \frac{\sum_{i=1}^{m_k} m_i^{(client)} \, \lVert \kappa_i^{(client)} - \kappa_k \rVert_2}{\left( \sum_{i=1}^{m_k} m_i^{(client)} \right) \log\!\left( \sum_{i=1}^{m_k} m_i^{(client)} + \alpha \right)},   (6.17)
where m_i^{(client)} is the number of data samples in the local cluster with centroid \kappa_i^{(client)}. Notice that the local clustering on clients can have a different number of clusters from the final global clustering, which corresponds to the data locality that each client may only own samples from a subset of all data domains.

After the global clustering, the server sends the globally clustered prototypes back to the previously selected clients, which then optimize their parameters locally as described in the above local training procedure.
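A minimal sketch of this server-side aggregation is shown below; scikit-learn's KMeans is used for the sample-weighted global clustering, and the function names and normalization details are illustrative choices rather than the exact implementation.

```python
# A minimal sketch of aggregating locally uploaded centroids into K global
# prototypes, with the weighted concentration of Equation 6.17.
import numpy as np
from sklearn.cluster import KMeans

def aggregate_prototypes(local_centroids, local_counts, K, alpha=10.0):
    """local_centroids: (N, d) centroids uploaded by all selected clients;
    local_counts: (N,) numbers of samples behind each uploaded centroid."""
    km = KMeans(n_clusters=K, n_init=10).fit(local_centroids,
                                             sample_weight=local_counts)
    centers, labels = km.cluster_centers_, km.labels_
    rho = np.ones(K)
    for k in range(K):
        mask = labels == k
        m_total = local_counts[mask].sum()
        if m_total == 0:
            continue
        dists = np.linalg.norm(local_centroids[mask] - centers[k], axis=1)
        # Weighted concentration of Equation 6.17.
        rho[k] = (local_counts[mask] * dists).sum() / (m_total * np.log(m_total + alpha))
    # Keep the prototypes L2-normalized, matching the requirement above.
    prototypes = centers / np.linalg.norm(centers, axis=1, keepdims=True)
    return prototypes, rho
```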
Inference In the inference stage on each client c, we do not have access to y_{c,i}; we instead infer the prototype information solely from X_{c,i}, and derive the label distribution and its expectation as follows:

P(y_{c,i} \mid X_{c,i}; \Theta) = \sum_{d_{c,i}} P(d_{c,i} \mid X_{c,i}; \Theta) \, P(y_{c,i} \mid X_{c,i}, d_{c,i}; \Theta),
\mathbb{E}(y_{c,i} \mid X_{c,i}; \Theta) = \sum_{d_{c,i}} P(d_{c,i} \mid X_{c,i}; \Theta) \, \mathbb{E}(y_{c,i} \mid X_{c,i}, d_{c,i}; \Theta),   (6.18)

where

P(d_{c,i} \mid X_{c,i}; \Theta) = \frac{P(X_{c,i} \mid d_{c,i}; \Theta) \, C_c(d_{c,i})}{\sum_{d=1}^{K} P(X_{c,i} \mid d; \Theta) \, C_c(d)} \propto P(X_{c,i} \mid d_{c,i}; \Theta) \, C_c(d_{c,i}).   (6.19)
When inference is conducted on unseen clients that did not take part in training, we do not know the prior distribution of domains and assume that it is uniform.

Comparing Equations 6.8 and 6.19, we notice that there exists a discrepancy between training and inference in the form of the posterior sample domain distribution, due to the lack of labels at inference time. We will show later in the ablation study that such a discrepancy hinders the model's performance. Therefore, we use Equation 6.19 in both the training and inference stages as the form of the posterior sample domain distribution.
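The inference rule therefore reduces to a posterior-weighted mixture over the K domain-specific predictors; a minimal NumPy sketch with hypothetical names follows.

```python
# A minimal sketch of the inference rule in Equations 6.18-6.19.
import numpy as np

def predict(x, encode, predictors, prototypes, rho, prior=None):
    v = encode(x)
    v = v / np.linalg.norm(v)
    logits = (prototypes @ v) / rho                  # Equation 6.10
    weights = np.exp(logits - logits.max())
    if prior is not None:                            # C_c(d); uniform on unseen clients
        weights = weights * prior
    weights = weights / weights.sum()                # Equation 6.19
    preds = np.stack([f(x) for f in predictors])     # E(y | X, d) per domain
    return (weights[:, None] * preds).sum(axis=0)    # Equation 6.18
```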
6.3 Experiments
6.3.1 Settings
Tasks and Datasets We evaluate the performance of SL-PFL as well as baseline methods on three
real-world datasets with corresponding tasks. We provide statistics of the above datasets in Table 6.1.
Temperature regression with Shifts Weather dataset [64]. The Shifts Weather dataset contains
10M records uniformly distributed between Sep 1, 2018, and Sep 1, 2019. These records are collected from
8K weather stations located across the world. The goal of this task is to predict the air temperature at a
certain location and timestamp.
We select the in-domain part of the dataset, which contains the same three climate types (Tropical, Dry, Mild Temperate) in training, validation, and test sets. We follow the provided data split, whose ratio is 83.7%/1.3%/15% for training, validation, and test respectively. We partition each split across 20 clients by
selecting two out of three climate types and sampling a subset of records in selected types for each client.
The number of samples per client follows a power law as in existing works [52, 94]. The first 10 clients
(“Training Clients”) allocated with training splits participate in federated training, while the test splits on
both training clients and the remaining clients (“Unseen Clients”) are used in evaluation.
The climate type is excluded from input features since in the real-world case domain information is
usually inaccessible. We use a multilayer perceptron (MLP) model with 3 hidden layers of size 128 as
the base prediction model.
Image classification with MiniDomainNet dataset [115]. MiniDomainNet is a widely used benchmark dataset for multi-source domain adaptation. It contains 140,006 images of 126 classes from 4 domains
(Clipart, Painting, Real, and Sketch). Similar to the partition of Shifts Weather, each client contains 2 out of
4 data domains, and the distribution of numbers of samples in each data domain across clients follows a
power law. Similar to Shifts Weather, 10 out of 20 clients participate in federated training and we evaluate
on all clients. Explicit domain indicators are excluded from input features. The base prediction model is a
5-layer CNN model.
Mortality prediction with eICU dataset [75]. The eICU dataset is a large multi-center critical care
database collected from 200,859 patients in 208 hospitals. The task is to use drug features to predict the
patient’s mortality when discharged from the ICU. For eICU, we partition the dataset by the hospital ID of
each record to simulate the real-world scenario that hospitals do not share data with others due to data
regulations. For each client, we randomly split its private data to training/validation/test splits with the
ratio 60%/20%/20%. We also drop clients that have fewer than 100 records or fewer than 6 positively labeled samples. Following the split from eICU, 91 out of 115 hospitals participate in federated training, and the performance on all hospitals is evaluated. We use a 3-layer MLP with hidden layers of sizes 256, 128, and 64 as the
prediction model.
Synthetic Dataset Partition While the eICU dataset includes the hospital ID of each patient, based on which we can partition the data, the Shifts Weather and MiniDomainNet datasets do not come with organization identifiers, and we need to manually partition them across clients to simulate the cross-silo federated learning scenario with intra-client data heterogeneity. For each dataset, we first partition the data by domain. We select the climate type (Tropical, Dry, Mild Temperate) and the domain label (Clipart, Painting, Real, Sketch) as the domain indicators of Shifts Weather and MiniDomainNet respectively. Then we use a power law to decide the number of samples of each domain for each client. This ensures that the ratio of data samples from different domains also varies across clients, which increases the data heterogeneity. We visualize the number of samples in each domain of our partition in Figure 6.2.
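For illustration, the following is a minimal sketch of one way such a partition could be generated; the power-law exponent, the use of NumPy's power distribution, and the per-client scaling are assumptions, not the exact recipe used to build the datasets.

```python
# A hypothetical sketch of a power-law client partition: every client is
# assigned a fixed number of domains, and per-domain sample counts vary
# across clients following a power law.
import numpy as np

def partition_clients(domain_indices, num_clients=20, domains_per_client=2,
                      a=1.5, seed=0):
    """domain_indices: dict mapping domain name -> array of sample indices."""
    rng = np.random.default_rng(seed)
    domains = list(domain_indices)
    clients = []
    for _ in range(num_clients):
        chosen = rng.choice(domains, size=domains_per_client, replace=False)
        client_idx = []
        for dom in chosen:
            pool = domain_indices[dom]
            frac = rng.power(a)                     # power-law share in (0, 1]
            n = max(1, int(frac * len(pool) / num_clients))
            client_idx.append(rng.choice(pool, size=n, replace=False))
        clients.append(np.concatenate(client_idx))
    return clients
```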
Baselines We evaluate the performance of SL-PFL against the following baselines, which cover a wide
range of recent works on FL: FedAvg, FedProx, LG-FedAvg, PerFedAvg, FedPer, FedRep, Ditto, FedEM,
FedFOMO, FedPCL, and FedNH. For FedAvg and FedProx, we also use their fine-tuned variants as baselines
(denoted as FedAvg∗ and FedProx∗). The fine-tuned variant is optimized for one epoch on test clients with local training data in evaluation. We provide details about hyperparameter tuning in Section 6.3.2.

[Figure 6.2: Data Partition Visualization of Shifts Weather and MiniDomainNet. Panels: (a) training sets on training clients, Shifts Weather; (b) training sets on training clients, MiniDomainNet; (c) test sets on training clients, Shifts Weather; (d) test sets on training clients, MiniDomainNet; (e) test sets on unseen clients, Shifts Weather; (f) test sets on unseen clients, MiniDomainNet. Each panel shows the number of samples per domain on each client.]

Table 6.1: Statistics of datasets.

Dataset                       Shifts Weather    MiniDomainNet               eICU
Task                          Regression        126-Class Classification    Binary Classification
#Clients (Training/Unseen)    10/10             10/10                       91/24
#Samples Per Client           34846∼692110      625∼40644                   124∼6762
#Samples (Total)              3349178           140006                      126389
Each client is trained for one epoch in one communication round. We train for 150 global rounds on
Shifts Weather and 200 global rounds on MiniDomainNet and eICU. We report the performance on the test
set of the best-performing model on the validation set. We repeat each experiment 3 times and report the
mean and standard error of metrics.
6.3.2 Hyperparameter Tuning
In this part, we provide details about hyperparameter tuning for the baselines and our proposed method.
Since the number of baselines is large, we only search for optimal hyperparameters on FedAvg, FedRep
and FedPer, and reuse the best settings for other methods except for FedProx, as we observe that the effect
of these hyperparameters is similar across these methods. We set the learning rate to 10^-3 for Shifts Weather and to 10^-4 for MiniDomainNet and eICU. The learning rate for FedProx is fine-tuned separately, and we use 10^-2 for Shifts Weather and MiniDomainNet, and 10^-3 for eICU. All baselines and SL-PFL are
optimized with the Adam optimizer unless otherwise stated. For our proposed SL-PFL, we set the client
domain number k = 2 and the global domain number K = 3 on the Shifts Weather dataset, k = K = 4 on
the MiniDomainNet dataset, and k = K = 8 on the eICU dataset based on fine-tuning with validation data.
We also study the effect of k and K in Section 6.3.6.
General Hyperparameters
• Learning rates and optimizers. We first searched for optimal learning rates and optimizers for each dataset on FedAvg, FedRep and FedPer. We searched for the optimal learning rate in [10^-4, 10^-3, 5×10^-3, 10^-2, 5×10^-2, 10^-1] with the Adam, SGD, and SGD with momentum (γ=0.9) optimizers respectively. We observed that the effect of learning rates and optimizers is similar across them. The best-performing learning rates are 10^-3 for Shifts Weather and 10^-4 for MiniDomainNet and eICU, all with the Adam optimizer. Since the numbers of possible combinations of the remaining baselines, candidate learning rates, and optimizers are large, we reused the best learning rate/optimizer combinations for the rest of the baselines except FedProx. We fine-tuned the learning rate for FedProx separately, as FedProx uses its own optimization mechanism. For FedProx, we searched for the optimal learning rate in [10^-4, 10^-3, 5×10^-3, 10^-2, 5×10^-2, 10^-1]. After fine-tuning, we used lr=10^-2 for Shifts Weather and MiniDomainNet, and lr=10^-3 for eICU.
• Global Epochs. We use early stopping on all clients for all methods and do not specify a certain
number of global epochs. In addition, we set the maximum global epochs to 150 for Shifts Weather
and 200 for MiniDomainNet and eICU to ensure that all methods can converge. For each method, we
keep the model performing the best on the private validation dataset for each client respectively, and
record its performance on the private test dataset. We report the average test performance across
clients as result. This simulates the scenario where each client participating in the FL training selects
the model version performing the best on its own data. .
• Others. Other common hyperparameters across all methods include: batch size=128, number of local
epochs=1, client selection ratio=1.0 (all clients are selected in each round).
Model-Specific Hyperparameters
• FedProx: We adjusted the parameter µ over the values [10^-3, 10^-2, 10^-1, 1, 10],
and found that only under the setting µ = 1 can the training converge stably on all datasets. Other
hyperparameters are set as weight decay=0.01, momentum=0.0, and dampening=0.0.
Architectures of Local Models For both Shifts Weather and eICU datasets, we use an MLP with 3
hidden layers. For the Shifts Weather dataset, each hidden layer has a dimension of 128. For the eICU
dataset, the hidden dimensions are 256, 128, and 64 respectively.
We use a 5-layer convolutional neural network (CNN) as the prediction model for MiniDomainNet. The
architecture is as follows:
Table 6.2: Architecture of the 5-layer CNN used in MiniDomainNet.
Layer Type    Size
Conv2D        5 × 5, 32
ReLU          -
MaxPool2D     2 × 2
Conv2D        5 × 5, 64
ReLU          -
MaxPool2D     2 × 2
Linear        64 × 5 × 5, 2048
ReLU          -
Linear        126
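For concreteness, a PyTorch sketch matching Table 6.2 is shown below; the 64 × 5 × 5 flattened size implies 5 × 5 feature maps before the first linear layer, so the 3 × 32 × 32 input resolution assumed here is an inference for illustration, not a documented setting.

```python
# A PyTorch sketch of the 5-layer CNN in Table 6.2. The assumed input size is
# 3x32x32, which yields the 64*5*5 flattened dimension of the first linear layer.
import torch
import torch.nn as nn

class MiniDomainNetCNN(nn.Module):
    def __init__(self, num_classes=126):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5),   # Conv2D 5x5, 32
            nn.ReLU(),
            nn.MaxPool2d(2),                   # MaxPool2D 2x2
            nn.Conv2d(32, 64, kernel_size=5),  # Conv2D 5x5, 64
            nn.ReLU(),
            nn.MaxPool2d(2),                   # MaxPool2D 2x2
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 5 * 5, 2048),       # Linear 64x5x5 -> 2048
            nn.ReLU(),
            nn.Linear(2048, num_classes),      # Linear -> 126
        )

    def forward(self, x):
        return self.classifier(self.features(x))

# Sanity check under the assumed input size:
# MiniDomainNetCNN()(torch.zeros(1, 3, 32, 32)).shape == (1, 126)
```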
6.3.3 Prediction Results
We report the average prediction performance metrics over training clients (clients participating in
federated training), unseen clients (clients appearing in evaluation but not in training), and all clients
of each dataset respectively in Table 6.3. For the regression task on Shifts Weather, we use the Root Mean Square Error (RMSE) as the metric. For the multi-class classification task on MiniDomainNet and the
binary classification task on eICU, we report the Accuracy and the Area Under the Precision-Recall Curve
(AUPRC) respectively.
Table 6.3: Average prediction errors across clients.
Each row lists, for each method, the metric on Training Clients, Unseen Clients, and All Clients of Shifts Weather (RMSE ↓), MiniDomainNet (Accuracy ↑), and eICU (AUPRC ↑), in that order:

Method | Shifts Weather (RMSE ↓): Training / Unseen / All | MiniDomainNet (Accuracy ↑): Training / Unseen / All | eICU (AUPRC ↑): Training / Unseen / All
Local 1.991±0.002 3.510±0.314 2.855±0.192 10.3%±1.1% 12.4%±1.4% 11.3%±1.2% 24.5%±1.2% 15.3%±3.0% 22.6%±1.4%
FedAvg [66] 1.932±0.001 1.900±0.001 1.917±0.001 21.2%±0.7% 25.9%±1.3% 23.6%±0.5% 27.1%±0.4% 26.2%±0.9% 26.9%±0.5%
FedAvg∗ 1.948±0.003 1.916±0.003 1.932±0.003 19.5%±0.6% 24.0%±1.2% 21.7%±0.3% 27.4%±0.3% 27.2%±0.8% 27.3%±0.1%
FedProx [52] 3.090±0.021 2.962±0.025 3.027±0.023 9.5%±0.7% 13.1%±1.4% 11.3%±0.8% 15.2%±0.9% 14.4%±1.3% 15.1%±1.0%
FedProx∗ 2.322±0.010 2.720±0.040 2.529±0.021 7.3%±1.1% 8.7%±0.3% 8.0%±0.4% 27.1%±0.9% 24.4%±1.0% 26.6%±0.8%
LG-FedAvg [57] 2.008±0.033 2.116±0.011 2.063±0.022 10.2%±0.4% 16.6%±0.6% 13.4%±0.4% 25.0%±0.9% 19.7%±3.8% 23.9%±1.5%
PerFedAvg [21] 2.497±0.014 5.606±0.424 4.340±0.277 18.4%±0.3% 24.1%±1.4% 21.3%±0.9% 27.3%±0.2% 27.7%±1.4% 27.4%±0.2%
FedRep [18] 1.967±0.002 1.934±0.003 1.951±0.001 17.2%±0.8% 17.6%±2.2% 17.4%±0.9% 27.7%±0.6% 26.2%±1.2% 27.4%±0.7%
FedPer [2] 1.950±0.002 1.916±0.002 1.937±0.006 13.6%±0.6% 23.5%±2.2% 18.5%±1.1% 28.3%±0.3% 27.1%±1.5% 28.0%±0.5%
Ditto [50] 1.958±0.007 1.916±0.006 1.937±0.006 10.0%±0.3% 24.4%±0.7% 17.2%±0.2% 25.3%±0.4% 27.0%±1.7% 25.6%±0.9%
FedEM [65] 2.124±0.014 2.129±0.055 2.127±0.030 20.9%±0.6% 25.3%±0.9% 23.1%±0.5% 27.3%±0.3% 26.3%±0.8% 27.1%±0.1%
FedFOMO [110] 1.989±0.001 - - 20.8%±0.4% - - 27.4%±0.3% - -
FedPCL [95] > 10 > 10 > 10 22.0%±0.5% 27.6%±1.3% 24.8%±0.8% 27.0%±0.5% 22.0%±0.2% 25.9%±0.5%
FedNH [20] > 10 > 10 > 10 19.0%±0.6% 22.0%±0.5% 20.5%±0.6% 26.6%±0.3% 22.2%±1.9% 25.7%±0.7%
SL-PFL 1.909±0.003 1.876±0.002 1.892±0.003 22.0%±0.5% 29.5%±0.3% 25.8%±0.4% 29.3%±1.2% 24.0%±1.0% 28.1%±1.1%
In particular, we adapt the FedPCL and FedNH methods for the regression task by discretizing the real-valued labels with a step size of 1.0. Notice that FedFOMO does not specify how to assign model weights for unseen clients in training, and thus we evaluate its performance on training clients only.
We have the following observations:
(1) Results show that compared to various existing global and personalized FL baselines, our proposed
method SL-PFL achieves the best prediction performance on all datasets and various types of tasks, including
regression, binary classification, and multi-class classification. On Shifts Weather and MiniDomainNet,
SL-PFL outperforms all baselines across both training clients and unseen clients. On eICU, SL-PFL also
achieves the best overall prediction performance.
(2) The basic FedAvg and its personalized variant achieve performance close to or even better than
other personalized baselines and serve as strong baselines across tasks.
(3) Prototype-based baselines (FedPCL, FedNH) that rely on categorical labels outperform most of the other baselines in the multi-class classification task but not in the binary classification and regression tasks.
6.3.4 Ablation Study: Effect of Loss Function Components
We conduct an ablation study on the effect of each component of Equation 6.15 with the Shifts Weather
dataset. As the prediction task is a supervised learning task, we always maintain Li,supervised in the overall
loss and remove one or both of Li,prototypical and Li,InfoNCE. Prediction errors on clients of these loss
variants are shown in Table 6.4.
Table 6.4: Comparison of the effect of different loss function components across clients.
Loss                                                        Shifts Weather (RMSE ↓)
L_{i,supervised}                                            2.072±0.006
L_{i,supervised} + L_{i,prototypical}                       2.099±0.018
L_{i,supervised} + L_{i,InfoNCE}                            1.939±0.023
L_{i,supervised} + L_{i,prototypical} + L_{i,InfoNCE}       1.909±0.003
Results show that adding both prototypical and InfoNCE losses to the supervised task-specific loss (as
defined in Equation 6.15) gives the best performance. We also notice that adding only the prototypical loss leads to performance even inferior to the supervised loss alone, indicating that representations trained with the InfoNCE loss are more distinguishable and boost the performance.
6.3.5 Ablation Study: Form of the Posterior Domain Distribution
Table 6.5: Comparison of the effect of different posterior domain distribution forms on the average prediction
errors across clients.
Method                       SL-PFL†         SL-PFL‡         SL-PFL
Shifts Weather (RMSE ↓)      1.930(0.006)    2.013(0.007)    1.909(0.003)
Here we validate the design of the posterior data domain distribution and compare its performance with the following variants: (1) SL-PFL†: use the one-hot categorical distribution with probability 1 on the category given by K-means clustering, as in [48]; (2) SL-PFL‡: use Equation 6.8 in training and Equation 6.19 in inference, instead of using Equation 6.19 for both. Results in Table 6.5 demonstrate that: (1) the soft form of the posterior distribution leads to slightly better performance than using the clustering results as the posterior, which validates our design in Section 6.2.3; (2) only using features to infer the posterior domain distribution mitigates the discrepancy between training and test and significantly improves the model's performance on test data.

Table 6.6: Average prediction performance (RMSE) across clients on Shifts Weather with varying k_client and K.

k_client \ K     2        3        4
2                1.928    1.909    1.957
3                -        1.946    2.120
4                -        -        1.987
6.3.6 Ablation Study: Effect of Numbers of Local and Global Domains in SL-PFL
In this section, we discuss the effect of the following hyperparameters in SL-PFL: client domain numbers
{k1, k2, . . . , kC} and the global domain number K. We vary these numbers in [2, 3, 4] and compare their
prediction performance on the Shift Weather dataset. Since evaluating all 3
C combinations of client domain
numbers is impractical, we keep kclient = k1 = · · · = kC and kclient ≤ K.
Table 6.6 shows the average RMSE across clients on the Shifts Weather dataset for different combinations of k_client and K. The ground-truth values are k_client = 2 and K = 3, given by the number of climate types on each client and overall, respectively. "-" stands for invalid combinations where K < k_client.
We can see that when both kclient and K are set to match the ground truth numbers of per-client data
domains and overall data domains, SL-PFL can achieve the lowest prediction error. Meanwhile, the other
combination k = K = 2 also provides competitive prediction accuracy (RMSE 1.928) compared to the
best-performing baseline (FedAvg with RMSE 1.932). Nevertheless, improper choices of k and K degrade the prediction performance.
6.3.7 Evaluation of Potential Privacy Leaks From Centroids
Table 6.7: Accuracy of membership inference with centroids on Shifts Weather.
Client ID   #Training Samples   Infer from encoder embedding   Infer from momentum encoder embedding   Random
1           377925              9.68%                          9.77%                                   13.47%
2           368153              25.81%                         26.31%                                  13.14%
3           552322              9.01%                          9.07%                                   19.71%
4           579012              9.74%                          10.04%                                  20.66%
5           116859              16.04%                         15.79%                                  4.17%
6           329683              13.57%                         13.58%                                  11.77%
7           255911              12.88%                         12.74%                                  9.13%
8           29165               14.17%                         13.74%                                  1.04%
9           57256               15.78%                         15.87%                                  2.04%
10          136379              14.73%                         14.27%                                  4.87%
Overall     2802035             13.11%                         13.22%                                  14.38%
[Figure 6.3: Test Error Curves of Baselines and SL-PFL. Panels: (a) Shifts Weather (RMSE vs. training epoch), (b) MiniDomainNet (Accuracy vs. training epoch), (c) eICU (AUPRC vs. training epoch), each comparing SL-PFL with the baseline methods.]
We evaluate whether the clustering centroids introduced in SL-PFL reveal information that can be used to conduct a membership inference attack. We assume that the attacker has access to all information visible to the central server and consider the following strategy: given a data sample, the attacker first encodes it with either (1) the encoder or (2) the momentum encoder, then finds the embedding's nearest centroid. The client providing the nearest centroid is inferred as the client owning the given data sample.
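A minimal NumPy sketch of this nearest-centroid attack, with hypothetical names, is given below.

```python
# A minimal sketch of the nearest-centroid membership inference strategy
# described above.
import numpy as np

def infer_source_client(v, centroids_per_client):
    """v: (d,) embedding of the probed sample (from either encoder);
    centroids_per_client: dict mapping client id -> (k_c, d) centroids.
    Returns the client whose closest centroid is nearest to v."""
    best_client, best_dist = None, np.inf
    for client, centroids in centroids_per_client.items():
        dist = np.linalg.norm(centroids - v, axis=1).min()
        if dist < best_dist:
            best_client, best_dist = client, dist
    return best_client
```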
On the Shifts Weather dataset, we evaluate the accuracy of the membership inference when the training data is directly used in the attack, which can be considered as the worst case, and compare it with a random baseline. Results are summarized in Table 6.7.
We assume that the attacker also knows the number of training examples, and the random guess is based on the number of samples of each client instead of being drawn from a uniform distribution. For client i with n_i training samples, the random guess has a probability of n_i / \sum_{c=1}^{C} n_c of guessing i as the source client. We report the overall accuracy of the random guess as the expected accuracy when we sample the example uniformly from all training examples: \sum_{i=1}^{C} \left( n_i / \sum_{c=1}^{C} n_c \right)^2.
We observe that the centroid-based membership inference attack achieves accuracy higher than the random baseline on client 2 as well as on clients with fewer samples (clients 5-10). However, even the worst case (directly using data identical to the training data on client 2) reaches only about 25% membership inference accuracy. Meanwhile, the overall membership inference accuracy is still close to the random guessing baseline. Therefore, the centroids from local clustering in SL-PFL do not provide an effective information source for membership inference attacks.
6.3.8 Convergence Behavior of SL-PFL
We report the test-set performance of the model performing best on the validation set throughout the training process for the Shifts Weather and eICU datasets in Figure 6.3. Compared to baselines, SL-PFL shows a similar convergence speed while achieving the best prediction performance on both datasets.
6.3.9 Computation And Communication Costs of SL-PFL
The computation overhead contains (1) the extra training and inference cost from using K domain-specific prediction models instead of a single prediction model; (2) the computation cost of local KMeans, O(T n_c k_c d), on each client c, where T is the number of local KMeans rounds, n_c is the number of local data samples, k_c is the number of local prototypes, and d is the embedding dimension; (3) the computation cost of global KMeans on the server, O(T' \sum_c k_c K d), where T' is the number of global KMeans rounds and K is the number of global prototypes.
The communication cost in each global round of training is

O\!\left( \sum_{c} k_c d + C K d + 2C(K M + 2 M_{enc}) \right),

where C is the number of clients, M is the size of the prediction model parameters, and M_{enc} is the size of the encoder/momentum encoder parameters.
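As a quick sanity check, the per-round communication cost can be evaluated with a few lines of Python; the helper below simply mirrors the symbols in the expression above, and the argument names are illustrative.

```python
# Evaluate the per-round communication cost formula above.
def comm_cost_per_round(k_c_list, d, K, M, M_enc):
    C = len(k_c_list)
    # uploaded local centroids + broadcast global prototypes
    prototype_cost = sum(k_c_list) * d + C * K * d
    # up- and down-link of K prediction models and two encoders per client
    model_cost = 2 * C * (K * M + 2 * M_enc)
    return prototype_cost + model_cost
```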
6.4 Conclusion
We propose Sample-Level Prototypical Federated Learning (SL-PFL), which is based on the fine-grained
sample-level data distribution factorization and provides each data sample with a personalized model.
Inspired by recent advances in prototypical representation learning with centralized data, we develop
a federated semi-supervised prototypical contrastive learning algorithm to train data embeddings with
prototypes reflecting domain information, as well as domain-specific prediction models. Future directions
along our work include automatically recovering the number of local and global domains instead of relying on hyperparameter tuning, and integrating domain adaptation to handle data domains unseen in the training data at test time.
Chapter 7
Conclusions
The dissertation presents a panorama of my works on building trustworthy spatiotemporal prediction models with proposed approaches in physics-informed machine learning and federated learning. Starting from the success of data-driven models in spatiotemporal prediction, the works in the dissertation address three main challenges of trustworthiness (robustness, interpretability, and privacy preservation) from both its vertical and horizontal dimensions. With contributions in these aspects, the dissertation formulates a cohesive framework that integrates trustworthiness into spatiotemporal prediction models.
Vertical trustworthiness With the focus on improving model generalizability and interpretability, the dissertation presents contributions to vertical trustworthiness through the development of novel physics-informed machine learning techniques, including the Physics-Aware Difference Graph Networks (PA-DGN), inspired by numerical partial differential equation solvers, and the Spatiotemporal Koopman Multi-Resolution Network (ST-KMRN), which integrates Koopman theory into long-term forecasting. With the aid of external knowledge represented as physics constraints, spatiotemporal models not only generalize to horizons and resolutions beyond the training data, but also provide insights into properties of the input data such as periodicity.
Horizontal Trustworthiness In terms of strengthening data privacy preservation, the dissertation proposes a federated learning spatiotemporal prediction model, the Cross-Node Federated Graph Neural Network (CNFGNN), that respects privacy and data localization constraints while leveraging graph neural networks to model complex spatiotemporal dependencies across clients. Enhanced preservation of data privacy enables spatiotemporal prediction models to utilize a broader range of data providers.
Comprehensive Trustworthiness The dissertation demonstrates the Sample-Level Prototypical Federated Learning (SL-PFL) as a solution for achieving comprehensive trustworthiness in both the vertical and horizontal dimensions. By informing the model of prior knowledge of the data distribution, SL-PFL can address the challenges of both inter-client and intra-client data heterogeneity and provide a fine-grained personalized model for each data sample under the data localization constraint.
Several future research directions emerge with the recent developments in knowledge representation and federated learning. (1) Flexible and universal selection and integration of external knowledge in machine learning models. The rapid development of large language models provides a universal solution to extracting and representing human knowledge recorded in natural language. Exploiting the inherent knowledge in pretrained large language models leads to a flexible and universal way of constructing knowledge-informed models and remains to be explored. (2) Improved client model aggregation and personalization in federated learning. Due to data localization and the resulting non-identically and independently distributed data, improved utilization of the dependency and similarity across clients and refined adaptation to local data in deployment call for further development. (3) Optimized knowledge integration and aggregation in knowledge-informed federated learning. Integrating external knowledge to address data heterogeneity and extracting aggregated knowledge discovered on the client side form the flywheel of continuously improving federated learning models, and both directions await further exploration.

In conclusion, the dissertation explores and expands the boundary of trustworthy machine learning models and their application to spatiotemporal prediction tasks. By integrating trustworthiness along both the vertical and horizontal dimensions in data-driven models, the works presented in the dissertation improve the generalizability, interpretability, and privacy-preservation properties of prior works, and lay a solid foundation for both the real-world application of machine learning in critical domains and future research on model trustworthiness.
Bibliography
[1] Nina Amenta and Yong Joo Kil. “Defining point-set surfaces”. In: ACM Transactions on Graphics
(TOG). Vol. 23. ACM. 2004, pp. 264–270.
[2] Manoj Ghuhan Arivazhagan, Vinay Aggarwal, Aaditya Kumar Singh, and Sunav Choudhary.
“Federated learning with personalization layers”. In: arXiv preprint arXiv:1912.00818 (2019).
[3] Omri Azencot, N Benjamin Erichson, Vanessa Lin, and Michael W Mahoney. “Forecasting
sequential data using consistent Koopman autoencoders”. In: ICML. 2020.
[4] Lei Bai, Lina Yao, Can Li, Xianzhi Wang, and Can Wang. “Adaptive Graph Convolutional Recurrent
Network for Traffic Forecasting”. In: Advances in Neural Information Processing Systems 33 (2020).
[5] Peter Battaglia, Razvan Pascanu, Matthew Lai, Danilo Jimenez Rezende, et al. “Interaction networks
for learning about objects, relations and physics”. In: Advances in neural information processing
systems. 2016, pp. 4502–4510.
[6] Peter W Battaglia, Jessica B Hamrick, Victor Bapst, Alvaro Sanchez-Gonzalez, Vinicius Zambaldi,
Mateusz Malinowski, Andrea Tacchetti, David Raposo, Adam Santoro, Ryan Faulkner, et al.
“Relational inductive biases, deep learning, and graph networks”. In: arXiv preprint arXiv:1806.01261
(2018).
[7] Shai Ben-David, Tyler Lu, and Dávid Pál. “Does Unlabeled Data Provably Help? Worst-case
Analysis of the Sample Complexity of Semi-Supervised Learning.” In: COLT. 2008, pp. 33–44.
[8] Samy Bengio, Oriol Vinyals, Navdeep Jaitly, and Noam Shazeer. “Scheduled sampling for sequence
prediction with recurrent neural networks”. In: Advances in Neural Information Processing Systems.
2015, pp. 1171–1179.
[9] Emmanuel de Bezenac, Arthur Pajot, and Patrick Gallinari. “Deep Learning for Physical Processes:
Incorporating Prior Scientific Knowledge”. In: International Conference on Learning Representations.
2018. url: https://openreview.net/forum?id=By4HsfWAZ.
[10] Miles Brundage, Shahar Avin, Jasmine Wang, Haydn Belfield, Gretchen Krueger, Gillian Hadfield,
Heidy Khlaaf, Jingying Yang, Helen Toner, Ruth Fong, et al. “Toward trustworthy AI development:
mechanisms for supporting verifiable claims”. In: arXiv preprint arXiv:2004.07213 (2020).
[11] Defu Cao, Yujing Wang, Juanyong Duan, Ce Zhang, Xia Zhu, Congrui Huang, Yunhai Tong,
Bixiong Xu, Jing Bai, Jie Tong, et al. “Spectral temporal graph neural network for multivariate
time-series forecasting”. In: Advances in neural information processing systems 33 (2020),
pp. 17766–17778.
[12] Michael B Chang, Tomer Ullman, Antonio Torralba, and Joshua B Tenenbaum. “A compositional
object-based approach to learning physical dynamics”. In: arXiv preprint arXiv:1612.00341 (2016).
[13] Tian Qi Chen, Yulia Rubanova, Jesse Bettencourt, and David K Duvenaud. “Neural ordinary
differential equations”. In: Advances in neural information processing systems. 2018, pp. 6571–6583.
[14] Weiqi Chen, Ling Chen, Yu Xie, Wei Cao, Yusong Gao, and Xiaojie Feng. “Multi-range attentive
bicomponent graph convolutional network for traffic forecasting”. In: Proceedings of the AAAI
Conference on Artificial Intelligence. Vol. 34. 2020, pp. 3529–3536.
[15] Yunjin Chen, Wei Yu, and Thomas Pock. “On learning optimized reaction diffusion processes for
effective image restoration”. In: Proceedings of the IEEE conference on computer vision and pattern
recognition. 2015, pp. 5261–5269.
[16] Kyunghyun Cho, Bart van Merriënboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares,
Holger Schwenk, and Yoshua Bengio. “Learning Phrase Representations using RNN
Encoder–Decoder for Statistical Machine Translation”. In: Proceedings of the 2014 Conference on
Empirical Methods in Natural Language Processing (EMNLP). 2014, pp. 1724–1734.
[17] Junyoung Chung, Caglar Gulcehre, KyungHyun Cho, and Yoshua Bengio. “Empirical evaluation of
gated recurrent neural networks on sequence modeling”. In: arXiv preprint arXiv:1412.3555 (2014).
[18] Liam Collins, Hamed Hassani, Aryan Mokhtari, and Sanjay Shakkottai. “Exploiting Shared
Representations for Personalized Federated Learning”. In: International Conference on Machine
Learning. 2021.
[19] Keenan Crane. “Discrete differential geometry: An applied introduction”. In: Notices of the AMS,
Communication (2018).
[20] Yutong Dai, Zeyuan Chen, Junnan Li, Shelby Heinecke, Lichao Sun, and Ran Xu. “Tackling Data
Heterogeneity in Federated Learning with Class Prototypes”. In: AAAI Conference on Artificial
Intelligence. 2023.
[21] Alireza Fallah, Aryan Mokhtari, and Asuman Ozdaglar. “Personalized federated learning: A
meta-learning approach”. In: Advances in Neural Information Processing Systems (2020).
[22] Xu Geng, Yaguang Li, Leye Wang, Lingyu Zhang, Qiang Yang, Jieping Ye, and Yan Liu.
“Spatiotemporal multi-graph convolution network for ride-hailing demand forecasting”. In: 2019
AAAI Conference on Artificial Intelligence (AAAI’19). 2019.
[23] Avishek Ghosh, Jichan Chung, Dong Yin, and Kannan Ramchandran. “An Efficient Framework for
Clustered Federated Learning”. In: Advances in Neural Information Processing Systems 33 (2020).
[24] Justin Gilmer, Samuel S Schoenholz, Patrick F Riley, Oriol Vinyals, and George E Dahl. “Neural
message passing for quantum chemistry”. In: Proceedings of the 34th International Conference on
Machine Learning-Volume 70. JMLR. org. 2017, pp. 1263–1272.
[25] Vincent Le Guen and Nicolas Thome. “Disentangling physical dynamics from unknown factors for
unsupervised video prediction”. In: Proceedings of the IEEE/CVF Conference on Computer Vision and
Pattern Recognition. 2020, pp. 11474–11484.
[26] Will Hamilton, Zhitao Ying, and Jure Leskovec. “Inductive representation learning on large graphs”.
In: Advances in neural information processing systems. 2017, pp. 1024–1034.
[27] Kaiming He, Haoqi Fan, Yuxin Wu, Saining Xie, and Ross Girshick. “Momentum contrast for
unsupervised visual representation learning”. In: Proceedings of the IEEE/CVF Conference on
Computer Vision and Pattern Recognition. 2020, pp. 9729–9738.
[28] Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. “Deep residual learning for image
recognition”. In: Proceedings of the IEEE conference on computer vision and pattern recognition. 2016,
pp. 770–778.
[29] Geoffrey Hinton, Li Deng, Dong Yu, George E Dahl, Abdel-rahman Mohamed, Navdeep Jaitly,
Andrew Senior, Vincent Vanhoucke, Patrick Nguyen, Tara N Sainath, et al. “Deep neural networks
for acoustic modeling in speech recognition: The shared views of four research groups”. In: IEEE
Signal processing magazine 29.6 (2012), pp. 82–97.
[30] Kurt Hornik, Maxwell Stinchcombe, and Halbert White. “Multilayer feedforward networks are
universal approximators”. In: Neural networks 2.5 (1989), pp. 359–366.
[31] Wenbing Huang, Tong Zhang, Yu Rong, and Junzhou Huang. “Adaptive sampling towards fast
graph representation learning”. In: Advances in neural information processing systems. 2018,
pp. 4558–4567.
[32] Zijie Huang, Yizhou Sun, and Wei Wang. “Learning Continuous System Dynamics from
Irregularly-Sampled Partial Observations”. In: Advances in Neural Information Processing Systems 33
(2020).
[33] Valerii Iakovlev, Markus Heinonen, and Harri Lähdesmäki. “Learning continuous-time PDEs from
sparse data with graph neural networks”. In: International Conference on Learning Representations.
2020.
[34] Chiyu Max Jiang, Jingwei Huang, Karthik Kashinath, Prabhat, Philip Marcus, and
Matthias Niessner. “Spherical CNNs on Unstructured Grids”. In: International Conference on
Learning Representations. 2019. url: https://openreview.net/forum?id=Bkl-43C9FQ.
[35] Peter Kairouz, H Brendan McMahan, Brendan Avent, Aurélien Bellet, Mehdi Bennis,
Arjun Nitin Bhagoji, Keith Bonawitz, Zachary Charles, Graham Cormode, Rachel Cummings, et al.
“Advances and open problems in federated learning”. In: arXiv preprint arXiv:1912.04977 (2019).
[36] Sai Praneeth Karimireddy, Satyen Kale, Mehryar Mohri, Sashank Reddi, Sebastian Stich, and
Ananda Theertha Suresh. “Scaffold: Stochastic controlled averaging for federated learning”. In:
International Conference on Machine Learning. PMLR. 2020, pp. 5132–5143.
[37] Sai Praneeth Karimireddy, Satyen Kale, Mehryar Mohri, Sashank J Reddi, Sebastian U Stich, and
Ananda Theertha Suresh. “Scaffold: Stochastic controlled averaging for federated learning”. In:
Proceedings of the 37th International Conference on Machine Learning. 2020.
[38] Anuj Karpatne, William Watkins, Jordan Read, and Vipin Kumar. “Physics-guided neural networks
(pgnn): An application in lake temperature modeling”. In: arXiv preprint arXiv:1710.11431 (2017).
[39] I Kevrekidis, Clarence W Rowley, and M Williams. “A kernel-based method for data-driven
Koopman spectral analysis”. In: Journal of Computational Dynamics 2.2 (2016), pp. 247–265.
[40] Thomas N Kipf, Ethan Fetaya, Kuan-Chieh Wang, Max Welling, and Richard S Zemel. “Neural
Relational Inference for Interacting Systems”. In: ICML. 2018.
[41] Thomas N Kipf and Max Welling. “Semi-supervised classification with graph convolutional
networks”. In: Proceedings of the International Conference on Learning Representations (ICLR). 2017.
[42] Thomas N. Kipf and Max Welling. “Semi-Supervised Classification with Graph Convolutional
Networks”. In: International Conference on Learning Representations (ICLR). 2017.
[43] Bernard O Koopman. “Hamiltonian systems and transformation in Hilbert space”. In: Proceedings of
the national academy of sciences of the united states of america 17.5 (1931), p. 315.
[44] Alex Krizhevsky, Ilya Sutskever, and Geoffrey E Hinton. “Imagenet classification with deep
convolutional neural networks”. In: Advances in neural information processing systems. 2012,
pp. 1097–1105.
[45] Guokun Lai, Wei-Cheng Chang, Yiming Yang, and Hanxiao Liu. “Modeling long-and short-term
temporal patterns with deep neural networks”. In: The 41st International ACM SIGIR Conference on
Research & Development in Information Retrieval. 2018, pp. 95–104.
[46] Rongjie Lai, Jiang Liang, and Hongkai Zhao. “A LOCAL MESH METHOD FOR SOLVING PDES ON
POINT CLOUDS.” In: Inverse Problems & Imaging 7.3 (2013).
[47] Fuxian Li, Jie Feng, Huan Yan, Guangyin Jin, Depeng Jin, and Yong Li. Dynamic Graph
Convolutional Recurrent Network for Traffic Prediction: Benchmark and Solution. 2021. arXiv:
2104.14917 [cs.LG].
[48] Junnan Li, Pan Zhou, Caiming Xiong, and Steven Hoi. “Prototypical Contrastive Learning of
Unsupervised Representations”. In: International Conference on Learning Representations. 2021.
[49] Max Guangyu Li, Bo Jiang, Hao Zhu, Zhengping Che, and Yan Liu. “Generative Attention Networks
for Multi-Agent Behavioral Modeling.” In: AAAI. 2020.
[50] Tian Li, Shengyuan Hu, Ahmad Beirami, and Virginia Smith. “Ditto: Fair and robust federated
learning through personalization”. In: International Conference on Machine Learning. PMLR. 2021,
pp. 6357–6368.
[51] Tian Li, Anit Kumar Sahu, Manzil Zaheer, Maziar Sanjabi, Ameet Talwalkar, and Virginia Smith.
“Federated optimization in heterogeneous networks”. In: Proceedings of the 3rd MLSys Conference.
2020.
[52] Tian Li, Anit Kumar Sahu, Manzil Zaheer, Maziar Sanjabi, Ameet Talwalkar, and Virginia Smith.
“Federated optimization in heterogeneous networks”. In: Proceedings of Machine Learning and
Systems 2 (2020), pp. 429–450.
[53] Xiang Li, Kaixuan Huang, Wenhao Yang, Shusen Wang, and Zhihua Zhang. “On the Convergence
of FedAvg on Non-IID Data”. In: International Conference on Learning Representations. 2020.
[54] Yaguang Li, Rose Yu, Cyrus Shahabi, and Yan Liu. “Diffusion Convolutional Recurrent Neural
Network: Data-Driven Traffic Forecasting”. In: International Conference on Learning Representations
(ICLR ’18). 2018.
[55] Yaguang Li, Rose Yu, Cyrus Shahabi, and Yan Liu. “Diffusion Convolutional Recurrent Neural
Network: Data-Driven Traffic Forecasting”. In: International Conference on Learning Representations.
2018. url: https://openreview.net/forum?id=SJiHXGWAZ.
[56] Yunzhu Li, Hao He, Jiajun Wu, Dina Katabi, and Antonio Torralba. “Learning Compositional
Koopman Operators for Model-Based Control”. In: International Conference on Learning
Representations. 2020. url: https://openreview.net/forum?id=H1ldzA4tPr.
[57] Paul Pu Liang, Terrance Liu, Liu Ziyin, Ruslan Salakhutdinov, and Louis-Philippe Morency. “Think
locally, act globally: Federated learning with local and global representations”. In: arXiv preprint
arXiv:2001.01523 (2020).
[58] Lek-Heng Lim. “Hodge Laplacians on graphs”. In: arXiv preprint arXiv:1507.05379 (2015).
[59] Ziyu Liu, Hongwen Zhang, Zhenghao Chen, Zhiyong Wang, and Wanli Ouyang. “Disentangling
and Unifying Graph Convolutions for Skeleton-Based Action Recognition”. In: Proceedings of the
IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2020, pp. 143–152.
[60] Zichao Long, Yiping Lu, Xianzhong Ma, and Bin Dong. “Pde-net: Learning pdes from data”. In:
International Conference on Machine Learning (2018).
[61] Chuanjiang Luo, Issam Safa, and Yusu Wang. “Approximating gradients for meshes and point
clouds via diffusion metric”. In: Computer Graphics Forum. Vol. 28. Wiley Online Library. 2009,
pp. 1497–1508.
[62] Bethany Lusch, J Nathan Kutz, and Steven L Brunton. “Deep learning for universal linear
embeddings of nonlinear dynamics”. In: Nature communications 9.1 (2018), pp. 1–10.
[63] Michael Lutter, Christian Ritter, and Jan Peters. “Deep Lagrangian Networks: Using Physics as
Model Prior for Deep Learning”. In: International Conference on Learning Representations. 2019. url:
https://openreview.net/forum?id=BklHpjCqKm.
[64] Andrey Malinin, Neil Band, German Chesnokov, Yarin Gal, Mark JF Gales, Alexey Noskov,
Andrey Ploskonosov, Liudmila Prokhorenkova, Ivan Provilkov, Vatsal Raina, et al. “Shifts: A dataset
of real distributional shift across multiple large-scale tasks”. In: arXiv preprint arXiv:2107.07455
(2021).
[65] Othmane Marfoq, Giovanni Neglia, Aurélien Bellet, Laetitia Kameni, and Richard Vidal. “Federated
multi-task learning under a mixture of distributions”. In: Advances in Neural Information Processing
Systems 34 (2021).
[66] Brendan McMahan, Eider Moore, Daniel Ramage, Seth Hampson, and Blaise Aguera y Arcas.
“Communication-efficient learning of deep networks from decentralized data”. In: Artificial
Intelligence and Statistics. PMLR. 2017, pp. 1273–1282.
[67] Guangxu Mei, Ziyu Guo, Shijun Liu, and Li Pan. “SGNN: A Graph Neural Network Based Federated
Learning Approach by Hiding Structure”. In: 2019 IEEE International Conference on Big Data (Big
Data). IEEE. 2019, pp. 2560–2568.
[68] Chuizheng Meng, Hao Niu, Guillaume Habault, Roberto Legaspi, Shinya Wada, Chihiro Ono, and
Yan Liu. “Physics-Informed Long-Sequence Forecasting From Multi-Resolution Spatiotemporal
Data”. In: Proceedings of the Thirty-First International Joint Conference on Artificial Intelligence. 2022.
[69] Chuizheng Meng, Sirisha Rambhatla, and Yan Liu. “Cross-node federated graph neural network for
spatio-temporal data modeling”. In: Proceedings of the 27th ACM SIGKDD Conference on Knowledge
Discovery & Data Mining. 2021, pp. 1202–1211.
[70] Arvind T Mohan and Datta V Gaitonde. “A deep learning based approach to reduced order
modeling for turbulent flow control using LSTM neural networks”. In: arXiv preprint
arXiv:1804.09269 (2018).
[71] NREL. https://www.nrel.gov/grid/solar-power-data.html. Accessed: 2021-10-01. 2021.
[72] NYCTLC. https://www1.nyc.gov/site/tlc/about/tlc-trip-record-data.page. Accessed: 2021-10-01.
2021.
[73] Aaron van den Oord, Sander Dieleman, Heiga Zen, Karen Simonyan, Oriol Vinyals, Alex Graves,
Nal Kalchbrenner, Andrew Senior, and Koray Kavukcuoglu. “Wavenet: A generative model for raw
audio”. In: arXiv preprint arXiv:1609.03499 (2016).
[74] Aaron van den Oord, Yazhe Li, and Oriol Vinyals. “Representation learning with contrastive
predictive coding”. In: arXiv preprint arXiv:1807.03748 (2018).
[75] Tom J Pollard, Alistair EW Johnson, Jesse D Raffa, Leo A Celi, Roger G Mark, and Omar Badawi.
“The eICU Collaborative Research Database, a freely available multi-center database for critical care
research”. In: Scientific data 5.1 (2018), pp. 1–13.
[76] Maziar Raissi, Paris Perdikaris, and George Em Karniadakis. “Physics Informed Deep Learning (Part
I): Data-driven Solutions of Nonlinear Partial Differential Equations”. In: arXiv preprint
arXiv:1711.10561 (2017).
[77] Maziar Raissi, Paris Perdikaris, and George Em Karniadakis. “Physics Informed Deep Learning (Part
II): Data-driven Discovery of Nonlinear Partial Differential Equations”. In: arXiv preprint
arXiv:1711.10566 (2017).
[78] Stephan Rasp, Peter D Dueben, Sebastian Scher, Jonathan A Weyn, Soukayna Mouatadid, and
Nils Thuerey. “WeatherBench: a benchmark data set for data-driven weather forecasting”. In:
Journal of Advances in Modeling Earth Systems 12.11 (2020), e2020MS002203.
[79] Yulia Rubanova, Ricky T. Q. Chen, and David K Duvenaud. “Latent Ordinary Differential Equations
for Irregularly-Sampled Time Series”. In: Advances in Neural Information Processing Systems. Vol. 32.
2019, pp. 5320–5330.
[80] Lars Ruthotto and Eldad Haber. “Deep neural networks motivated by partial differential equations”.
In: arXiv preprint arXiv:1804.04272 (2018).
[81] Sina Sajadmanesh and Daniel Gatica-Perez. “When Differential Privacy Meets Graph Neural
Networks”. In: arXiv preprint arXiv:2006.05535 (2020).
[82] Alvaro Sanchez-Gonzalez, Nicolas Heess, Jost Tobias Springenberg, Josh Merel, Martin Riedmiller,
Raia Hadsell, and Peter Battaglia. “Graph networks as learnable physics engines for inference and
control”. In: International Conference on Machine Learning (2018).
[83] Adam Santoro, David Raposo, David G Barrett, Mateusz Malinowski, Razvan Pascanu,
Peter Battaglia, and Timothy Lillicrap. “A simple neural network module for relational reasoning”.
In: Advances in neural information processing systems. 2017, pp. 4967–4976.
[84] Felix Sattler, Klaus-Robert Müller, and Wojciech Samek. “Clustered federated learning:
Model-agnostic distributed multitask optimization under privacy constraints”. In: IEEE transactions
on neural networks and learning systems (2020).
[85] Peter J Schmid. “Dynamic mode decomposition of numerical and experimental data”. In: Journal of
fluid mechanics 656 (2010), pp. 5–28.
[86] Sungyong Seo and Yan Liu. “Differentiable Physics-informed Graph Networks”. In: arXiv preprint
arXiv:1902.02950 (2019).
[87] Sungyong Seo, Chuizheng Meng, and Yan Liu. “Physics-aware Difference Graph Networks for
Sparsely-Observed Dynamics”. In: International Conference on Learning Representations. 2019.
[88] Sungyong Seo*, Chuizheng Meng*, and Yan Liu (*contributed equally). “Physics-aware Difference
Graph Networks for Sparsely-Observed Dynamics”. In: International Conference on Learning
Representations. 2020. url: https://openreview.net/forum?id=r1gelyrtwH.
[89] Jonathan Richard Shewchuk. “What is a good linear finite element? interpolation, conditioning,
anisotropy, and quality measures (preprint)”. In: University of California at Berkeley 73 (2002), p. 137.
[90] Wenzhe Shi, Jose Caballero, Ferenc Huszár, Johannes Totz, Andrew P Aitken, Rob Bishop,
Daniel Rueckert, and Zehan Wang. “Real-time single image and video super-resolution using an
efficient sub-pixel convolutional neural network”. In: Proceedings of the IEEE conference on computer
vision and pattern recognition. 2016, pp. 1874–1883.
[91] Abhishek Singh, Praneeth Vepakomma, Otkrist Gupta, and Ramesh Raskar. “Detailed comparison
of communication efficiency of split learning and federated learning”. In: arXiv preprint
arXiv:1909.09145 (2019).
[92] Virginia Smith, Chao-Kai Chiang, Maziar Sanjabi, and Ameet S Talwalkar. “Federated multi-task
learning”. In: Advances in Neural Information Processing Systems. 2017, pp. 4424–4434.
[93] Toyotaro Suzumura, Yi Zhou, Natahalie Barcardo, Guangnan Ye, Keith Houck, Ryo Kawahara,
Ali Anwar, Lucia Larise Stavarache, Daniel Klyashtorny, Heiko Ludwig, et al. “Towards Federated
Graph Learning for Collaborative Financial Crimes Detection”. In: arXiv preprint arXiv:1909.12946
(2019).
[94] Canh T Dinh, Nguyen Tran, and Tuan Dung Nguyen. “Personalized Federated Learning with
Moreau Envelopes”. In: Advances in Neural Information Processing Systems 33 (2020).
[95] Yue Tan, Guodong Long, Jie Ma, Lu Liu, Tianyi Zhou, and Jing Jiang. “Federated Learning from
Pre-Trained Models: A Contrastive Learning Approach”. In: Advances in Neural Information
Processing Systems (NeurIPS). 2022.
[96] Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
Łukasz Kaiser, and Illia Polosukhin. “Attention is All you Need”. In: Advances in Neural Information
Processing Systems. Ed. by I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus,
S. Vishwanathan, and R. Garnett. Vol. 30. Curran Associates, Inc., 2017, pp. 5998–6008. url:
https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf.
[97] Rui Wang, Karthik Kashinath, Mustafa Mustafa, Adrian Albert, and Rose Yu. “Towards
physics-informed deep learning for turbulent flow prediction”. In: arXiv preprint arXiv:1911.08655
(2019).
[98] Rui Wang, Karthik Kashinath, Mustafa Mustafa, Adrian Albert, and Rose Yu. “Towards
physics-informed deep learning for turbulent flow prediction”. In: Proceedings of the 26th ACM
SIGKDD International Conference on Knowledge Discovery & Data Mining. 2020, pp. 1457–1466.
[99] Matthew O Williams, Ioannis G Kevrekidis, and Clarence W Rowley. “A data–driven
approximation of the koopman operator: Extending dynamic mode decomposition”. In: Journal of
Nonlinear Science 25.6 (2015), pp. 1307–1346.
[100] Yonghui Wu, Mike Schuster, Zhifeng Chen, Quoc V Le, Mohammad Norouzi, Wolfgang Macherey,
Maxim Krikun, Yuan Cao, Qin Gao, Klaus Macherey, et al. “Google’s neural machine translation
system: Bridging the gap between human and machine translation”. In: arXiv preprint
arXiv:1609.08144 (2016).
[101] Zonghan Wu, Shirui Pan, Guodong Long, Jing Jiang, Xiaojun Chang, and Chengqi Zhang.
“Connecting the dots: Multivariate time series forecasting with graph neural networks”. In:
Proceedings of the 26th ACM SIGKDD International Conference on Knowledge Discovery & Data
Mining. 2020, pp. 753–763.
[102] Zonghan Wu, Shirui Pan, Guodong Long, Jing Jiang, and Chengqi Zhang. “Graph WaveNet for
Deep Spatial-Temporal Graph Modeling.” In: IJCAI. 2019.
[103] Keyulu Xu, Jingling Li, Mozhi Zhang, Simon S Du, Ken-ichi Kawarabayashi, and Stefanie Jegelka.
“What Can Neural Networks Reason About?” In: International Conference on Learning
Representations (ICLR). 2019.
[104] Sijie Yan, Yuanjun Xiong, and Dahua Lin. “Spatial temporal graph convolutional networks for
skeleton-based action recognition”. In: AAAI. 2018.
[105] Huaxiu Yao, Xianfeng Tang, Hua Wei, Guanjie Zheng, and Zhenhui Li. “Revisiting spatial-temporal
similarity: A deep learning framework for traffic prediction”. In: Proceedings of the AAAI conference
on artificial intelligence. Vol. 33. 2019, pp. 5668–5675.
[106] Rex Ying, Ruining He, Kaifeng Chen, Pong Eksombatchai, William L Hamilton, and Jure Leskovec.
“Graph convolutional neural networks for web-scale recommender systems”. In: Proceedings of the
24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. 2018,
pp. 974–983.
[107] Jiaxuan You, Rex Ying, and Jure Leskovec. “Position-aware graph neural networks”. In: Proceedings
of the 36th International Conference on Machine Learning. 2019.
[108] Bing Yu, Haoteng Yin, and Zhanxing Zhu. “Spatio-temporal Graph Convolutional Networks: A
Deep Learning Framework for Traffic Forecasting”. In: Proceedings of the 27th International Joint
Conference on Artificial Intelligence (IJCAI). 2018.
[109] Junbo Zhang, Yu Zheng, and Dekang Qi. “Deep spatio-temporal residual networks for citywide
crowd flows prediction”. In: Thirty-First AAAI Conference on Artificial Intelligence. 2017.
[110] Michael Zhang, Karan Sapra, Sanja Fidler, Serena Yeung, and Jose M. Alvarez. “Personalized
Federated Learning with First Order Model Optimization”. In: International Conference on Learning
Representations. 2021. url: https://openreview.net/forum?id=ehJqJQk9cw.
[111] Yue Zhao, Meng Li, Liangzhen Lai, Naveen Suda, Damon Civin, and Vikas Chandra. “Federated
learning with non-iid data”. In: arXiv preprint arXiv:1806.00582 (2018).
[112] Chuanpan Zheng, Xiaoliang Fan, Cheng Wang, and Jianzhong Qi. “Gman: A graph multi-attention
network for traffic prediction”. In: Proceedings of the AAAI Conference on Artificial Intelligence.
Vol. 34. 2020, pp. 1234–1241.
[113] Haoyi Zhou, Shanghang Zhang, Jieqi Peng, Shuai Zhang, Jianxin Li, Hui Xiong, and Wancai Zhang.
“Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting”. In: The
Thirty-Fifth AAAI Conference on Artificial Intelligence, AAAI 2021. AAAI Press, 2021, online.
[114] Jun Zhou, Chaochao Chen, Longfei Zheng, Xiaolin Zheng, Bingzhe Wu, Ziqi Liu, and Li Wang.
“Privacy-Preserving Graph Neural Network for Node Classification”. In: arXiv preprint
arXiv:2005.11903 (2020).
[115] Kaiyang Zhou, Yongxin Yang, Yu Qiao, and Tao Xiang. “Domain adaptive ensemble learning”. In:
IEEE Transactions on Image Processing 30 (2021), pp. 8008–8018.
[116] Zhengyang Zhou, Yang Wang, Xike Xie, Lianliang Chen, and Hengchang Liu. “RiskOracle: A
Minute-Level Citywide Traffic Accident Forecasting Framework”. In: Proceedings of the AAAI
Conference on Artificial Intelligence. Vol. 34. 2020, pp. 1258–1265.
[117] Jiawei Zhuang, Dmitrii Kochkov, Yohai Bar-Sinai, Michael P Brenner, and Stephan Hoyer. “Learned
discretizations for passive scalar advection in a 2-D turbulent flow”. In: arXiv preprint
arXiv:2004.05477 (2020).
Abstract
With the great success of data-driven machine learning methods, concerns with the trustworthiness of machine learning models have been emerging in recent years. From the modeling perspective, the lack of trustworthiness amplifies the effect of insufficient training data. Purely data-driven models without constraints from domain knowledge tend to suffer from over-fitting and lose generalizability to unseen data. Meanwhile, concerns with data privacy further obstruct the availability of data from more providers. On the application side, the absence of trustworthiness hinders the application of data-driven methods in domains such as spatiotemporal forecasting, which involves data from critical applications including traffic, climate, and energy. My dissertation constructs spatiotemporal prediction models with enhanced trustworthiness from both the model and the data aspects. For model trustworthiness, the dissertation focuses on improving the generalizability of models via the integration of physics knowledge. For data trustworthiness, the dissertation proposes a spatiotemporal forecasting model in the federated learning context, where data in a network of nodes is generated locally on each node and remains decentralized. Furthermore, the dissertation amalgamates the trustworthiness from both aspects and combines the generalizability of knowledge-informed models with the privacy preservation of federated learning for spatiotemporal modeling.
Conceptually similar
Physics-aware graph networks for spatiotemporal physical systems
Tensor learning for large-scale spatiotemporal analysis
Learning at the local level
Failure prediction for rod pump artificial lift systems
Towards combating coordinated manipulation to online public opinions on social media
Spatiotemporal prediction with deep learning on graphs
Latent space dynamics for interpretation, monitoring, and prediction in industrial systems
Transforming unstructured historical and geographic data into spatio-temporal knowledge graphs
Deep learning models for temporal data in health care
Data-driven multi-fidelity modeling for physical systems
Dynamic topology reconfiguration of Boltzmann machines on quantum annealers
Striking the balance: optimizing privacy, utility, and complexity in private machine learning
Learning controllable data generation for scalable model training
Generative foundation model assisted privacy-enhancing computing in human-centered machine intelligence
Coding centric approaches for efficient, scalable, and privacy-preserving machine learning in large-scale distributed systems
Practice-inspired trust models and mechanisms for differential privacy
Fair Machine Learning for Human Behavior Understanding
Scalable optimization for trustworthy AI: robust and fair machine learning
Interpretable machine learning models via feature interaction discovery
Artificial Decision Intelligence: integrating deep learning and combinatorial optimization
Asset Metadata
Creator: Meng, Chuizheng (author)
Core Title: Trustworthy spatiotemporal prediction models
School: Viterbi School of Engineering
Degree: Doctor of Philosophy
Degree Program: Computer Science
Degree Conferral Date: 2024-05
Publication Date: 04/18/2024
Defense Date: 04/17/2024
Publisher: Los Angeles, California (original), University of Southern California (original), University of Southern California. Libraries (digital)
Tag: federated learning, OAI-PMH Harvest, physics-informed machine learning, spatiotemporal data, trustworthy machine learning
Format: theses (aat)
Language: English
Contributor: Electronically uploaded by the author (provenance)
Advisor: Liu, Yan (committee chair), Neiswanger, Willie (committee member), Oberai, Assad (committee member)
Creator Email: chuizhem@usc.edu, mengcz95thu@gmail.com
Permanent Link (DOI): https://doi.org/10.25549/usctheses-oUC113880139
Unique identifier: UC113880139
Identifier: etd-MengChuizh-12835.pdf (filename)
Legacy Identifier: etd-MengChuizh-12835
Document Type: Dissertation
Rights: Meng, Chuizheng
Internet Media Type: application/pdf
Type: texts
Source: 20240418-usctheses-batch-1142 (batch), University of Southern California (contributing entity), University of Southern California Dissertations and Theses (collection)
Access Conditions: The author retains rights to his/her dissertation, thesis or other graduate work according to U.S. copyright law. Electronic access is being provided by the USC Libraries in agreement with the author, as the original true and official version of the work, but does not grant the reader permission to use the work if the desired use is covered by copyright. It is the author, as rights holder, who must provide use permission if such use is covered by copyright.
Repository Name: University of Southern California Digital Library
Repository Location: USC Digital Library, University of Southern California, University Park Campus MC 2810, 3434 South Grand Avenue, 2nd Floor, Los Angeles, California 90089-2810, USA
Repository Email: cisadmin@lib.usc.edu