As the development of the Internet of Things (IoT) continues, Federated Learning (FL) is gaining popularity as a distributed machine learning framework that does not compromise the data privacy of each participant. However, the data held by enterprises and factories in the IoT often have different distribution properties (Non-IID), leading to poor results in their federated learning.
1. Introduction
With the growth of the Internet of Things and advances in Big Data-driven artificial intelligence, the network’s data are increasingly created by geographically distributed enterprise endpoints and IoT devices. The IoT in the context of Big Data is growing rapidly in the industrial sector. However, centralized aggregation of industrial Big Data to cloud servers leads to unaffordable transmission overheads and also violates the data privacy of each enterprise or client, which results in distributed databases consisting of multiple “data islands”. In light of the challenges posed by “data islands” in the development and application of the Internet of Things, Federated Learning (FL)
[1] was first proposed in 2016 for collaborative learning under privacy constraints. It has been widely used in IoT tasks such as smart cities
[2], healthcare
[3,4,5] and financial security
[6]. Meanwhile, in the industrial Internet of Things, federated learning is gradually becoming the mainstream way to enable collaborative training among all parties while keeping sensitive enterprise data private.
Although federated learning does not require centralized aggregation of data in the cloud, data distributions are often skewed across enterprises in practice, which degrades FL performance. For example, the data collected by mobile terminal input methods are distributed differently for people with different operating habits. The sensor data that different plants use to classify equipment faults or detect quality defects also differ, both in the types of sensors used and in the distribution of fault types. In the medical field, wearable IoT sensor data for monitoring patient vital signs are likewise heavily used in artificial intelligence learning to enable online diagnostics such as expert systems, yet the data collected by these edge IoT devices still suffer from inconsistent feature and label distributions. In summary, this major challenge, known as Non-IID, hinders the application of federated learning to IoT in the context of Big Data, and this research aims to propose a generalized method to address it in federated learning for IoT applications.
Similar to the problem in continuous learning
[7], this variation among distributions causes each client to forget the global knowledge during their local updates, which in turn severely affects the performance and convergence in federated learning
[8,9].
Figure 1 illustrates the catastrophic forgetting in continuous learning and federated learning. This phenomenon is referred to as “Client-drift” in
[10]. FedProx
[11] constrains the local updates by adding a regularization term to the local objective function to regulate its update direction towards the global objective. In recent years, knowledge distillation
[12] has been widely used to transfer knowledge between models, both to compress model size and to improve model accuracy. To obtain a more robust global model, ref.
[13] combines federated learning with knowledge distillation to fine-tune the global model after aggregation with an additional public dataset. Ref.
[14] applies large-model self-distillation on the server side to better maintain global knowledge. All these methods require an auxiliary public dataset for knowledge distillation, and FedGEMS even requires a homogeneous dataset as auxiliary data. Because federated learning exhibits a forgetting phenomenon similar to that in continuous learning, especially under data heterogeneity, this research attempts to introduce knowledge distillation in the local training phase, using collaborative distillation of global and local models so that each retains the other’s knowledge. Inspired by Relational KD
[15], the high-dimensional relational knowledge naturally contained in a global model is distilled after each aggregation to achieve better performance in combination with single-sample knowledge. Relational knowledge distillation considers that “relationships” among knowledge are more representative of the teacher’s “knowledge” than separate representations. Similar to the view in linguistic structuralism
[16], which focuses on structural relationships in symbolic systems, primary information often lies in the structural relationships of the data embedding space rather than in individual representations. Meanwhile, to balance the constraints imposed by single-sample knowledge and relational knowledge, this research introduces an adaptive coefficient module that adjusts them dynamically.
Figure 1. Catastrophic forgetting. (a) Forgetting in continuous learning; (b) forgetting in federated learning.
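To make the notion of relational knowledge concrete, the following is a minimal sketch of one common form, the distance-wise relational loss from the Relational KD literature. It compares the pairwise-distance structure of a batch in the teacher (global) and student (local) embedding spaces instead of matching each sample on its own. Whether FedRAD uses exactly this variant is not stated here, and the function names are illustrative assumptions.

```python
import math
from itertools import combinations


def pairwise_distances(embeddings):
    """Euclidean distance for every unordered pair of samples in the batch."""
    return [math.dist(a, b) for a, b in combinations(embeddings, 2)]


def normalize(distances):
    """Divide by the mean distance so teacher and student scales are comparable."""
    mean = sum(distances) / len(distances)
    return [d / mean for d in distances] if mean > 0 else list(distances)


def huber(x, y, delta=1.0):
    """Huber loss between two scalars."""
    diff = abs(x - y)
    return 0.5 * diff * diff if diff <= delta else delta * (diff - 0.5 * delta)


def relational_distance_loss(teacher_emb, student_emb):
    """Mean Huber loss between the two normalized pairwise-distance structures."""
    t = normalize(pairwise_distances(teacher_emb))
    s = normalize(pairwise_distances(student_emb))
    return sum(huber(ti, si) for ti, si in zip(t, s)) / len(t)
```

Because only normalized distances are compared, a student whose embeddings are a uniformly scaled copy of the teacher's incurs essentially zero loss, and the two embedding spaces may even have different dimensions. This is the sense in which the structure, rather than any individual representation, carries the knowledge.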
Inspired by the above considerations, the researchers propose a relational adaptive distillation paradigm called Relational Adaptive Distillation for Heterogeneous Federated Learning, abbreviated as FedRAD. The aggregated global model is downloaded by the selected clients during each round of communication and collaboratively distilled
[17] with their own local models, transmitting both single-sample knowledge based on classical knowledge distillation and relational knowledge based on high-dimensional structural representations. This method draws on the various types of knowledge in distributed data, which encourages local models to learn higher-dimensional knowledge representations from the global model and minimizes the forgetting that occurs during local training in data-heterogeneous scenarios. To better balance the penalty placed on single-sample knowledge versus relational knowledge, the researchers further propose an entropy-wise adaptive weight (EWAW) strategy, which helps local models adaptively control the impact of distillation according to the global model’s predictions on each data batch and so prevents excessive transfer of negative knowledge. When the global model’s prediction is plausible, the local model learns single-sample knowledge and relational knowledge in a balanced way. Otherwise, the local model focuses more on relational knowledge.
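The exact EWAW formula is not given in this section, so the following is only a hypothetical sketch of the idea. It pairs the classical temperature-softened distillation loss (the single-sample knowledge) with a weight derived from the entropy of the global model's predictions on a batch. Treating low entropy as a "plausible" prediction, the `1 - normalized entropy` weighting, and all function names are assumptions made for illustration.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax with temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(teacher_logits, student_logits, temperature=2.0):
    """KL(teacher || student) on temperature-softened outputs, scaled by T^2."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    return temperature ** 2 * kl


def normalized_entropy(probs):
    """Shannon entropy divided by log(num_classes), so the result lies in [0, 1]."""
    h = -sum(p * math.log(p) for p in probs if p > 0)
    return h / math.log(len(probs))


def single_sample_weight(global_logits_batch):
    """Hypothetical weight: 1 minus the mean normalized entropy over the batch."""
    entropies = [normalized_entropy(softmax(z)) for z in global_logits_batch]
    return 1.0 - sum(entropies) / len(entropies)
```

Under these assumptions, a local objective could scale the single-sample term by `single_sample_weight(...)` while leaving the relational term unscaled, so that a confident global model yields balanced learning and an uncertain one shifts the emphasis toward relational knowledge.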
2. Federated Learning
With the emergence of various data privacy protection requirements, secure multi-party computing
[18,19] was commonly used in the past as the main way to resolve the conflict between data confidentiality and data sharing in the IoT. However, its large error accumulation and high computational cost make it poorly suited to deep learning scenarios. In contrast, federated learning, an emerging distributed learning paradigm, is widely used in deep learning. FedAvg
[1] is the classical federated learning paradigm: after each client performs some local updates, the server aggregates the parameters or gradients of all local models into a global model, with aggregation weights proportional to local data size. A key challenge of this classical paradigm is that the clients’ data are usually non-identically distributed (Non-IID). Many works have attempted to solve the Non-IID problem by improving the server aggregation phase or the local training phase. Refs.
[20,21] start from a clustering perspective: assuming that clients differ in how similar their data distributions are, they assign similar clients to a cluster and train a global model within each cluster to reduce the impact of non-identical distributions. Because these methods presuppose similarity among client data distributions, they perform poorly when the distributions vary too much, and they do not truly address the forgetting that occurs during local training. Refs.
[11,22,23] aim to improve the local training phase by adding a regularization term to the local model’s objective as a constraint that limits the deviation between local and global models and reduces client drift. Another technical route is to improve the server-side aggregation phase
[24,25].
Unlike this route, this paper aims to preserve the global model’s knowledge during the local training phase, and therefore belongs to the methods that improve local training.
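The two baseline mechanisms described above can be sketched in a few lines. The first function is size-weighted parameter averaging in the style of FedAvg, and the second is a FedProx-style proximal term that penalizes the distance between local and global parameters. Representing a model as a flat list of floats, and the function names, are simplifying assumptions for illustration.

```python
def fedavg_aggregate(client_params, client_sizes):
    """Average flat parameter vectors, weighting each client by its data size."""
    total = sum(client_sizes)
    dim = len(client_params[0])
    return [
        sum(n * params[i] for params, n in zip(client_params, client_sizes)) / total
        for i in range(dim)
    ]


def proximal_penalty(local_params, global_params, mu=0.01):
    """(mu / 2) * ||w_local - w_global||^2, added to the local loss."""
    sq = sum((w - g) ** 2 for w, g in zip(local_params, global_params))
    return 0.5 * mu * sq


# Two clients holding 1 and 3 samples: the larger client dominates the average.
global_params = fedavg_aggregate([[1.0, 0.0], [3.0, 4.0]], [1, 3])
```

The proximal term is zero when the local model equals the global model and grows quadratically as local updates drift, which is how this family of methods limits client drift without changing the aggregation step.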
3. Knowledge Distillation in FL
Knowledge distillation can transfer knowledge from large teacher models to small student models and is widely used for model compression
[26,27] and collaborative learning among students to improve performance
[17,28]. Knowledge distillation applied to federated learning has proven to be an effective way to address data heterogeneity. Many works focus on ensemble distillation, i.e., transferring knowledge to a global model as the student by aggregating client knowledge as the teacher. Ref.
[29] uses transfer learning on public datasets and ensemble distillation on the client side to improve model performance and reduce communication cost, transferring knowledge by exchanging only model predictions rather than model parameters. Ref.
[13] combines federated learning with knowledge distillation, fine-tuning the global model after aggregation on an additional public dataset to obtain a more robust global model. Ref.
[14] improves on FedMD by holding a large model on the server side for self-distillation to better preserve global knowledge, which also keeps the model from forgetting knowledge. However, all the FL methods above require an additional public dataset similar to the clients’ private datasets for knowledge distillation, and such carefully prepared public datasets are not always available. Some recent works attempt to extract knowledge without additional public datasets. Ref.
[30] averages, on the server side, the per-class logits sent by each client and distributes the result as distillation knowledge, avoiding reliance on public datasets. However, this method directly averages all logits of the same class, which tends to blur knowledge across clients. Ref.
[31] incorporates split learning, splitting the local model into a feature extraction network and a classification network and transmitting the intermediate features as inputs to the server-side classification network for knowledge distillation. This approach lowers the computing power required at the edge but performs poorly under data heterogeneity. Ref.
[32] proposes optimizing a triangular upper bound of the federated learning objective function, especially in Non-IID scenarios, using both local training and knowledge distillation to lower the bound and improve performance. Refs.
[33,34,35] use a pre-trained generator to produce pseudo data that serves as a public dataset to assist training; for example, FedFTG trains a generator against the global model to generate difficult pseudo data that helps the global model avoid forgetting. However, these generator-based methods require extremely high-quality generated samples, which often come at a large computational cost. Ref.
[36] performs knowledge distillation in the local training phase by broadcasting local data representations and the corresponding soft predictions, together termed “hyper-knowledge”. This method resembles ours in needing neither a generative model nor public datasets, but it is more concerned with balancing the performance of the local and global models.
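The per-class logit averaging mentioned above for public-data-free distillation, and the blurring it can cause, can be illustrated with a minimal sketch. This is an illustration of the general idea rather than the cited method's exact procedure; the function names and the unweighted server mean are assumptions.

```python
from collections import defaultdict


def client_class_logits(logits, labels):
    """Mean logit vector per class, computed on one client's local data."""
    sums, counts = {}, defaultdict(int)
    for z, y in zip(logits, labels):
        if y not in sums:
            sums[y] = [0.0] * len(z)
        sums[y] = [a + b for a, b in zip(sums[y], z)]
        counts[y] += 1
    return {y: [v / counts[y] for v in s] for y, s in sums.items()}


def server_average(per_client):
    """Unweighted mean of the clients' per-class logit vectors."""
    merged = defaultdict(list)
    for table in per_client:
        for y, vec in table.items():
            merged[y].append(vec)
    return {
        y: [sum(col) / len(vecs) for col in zip(*vecs)]
        for y, vecs in merged.items()
    }


# Client A is confident about class 0; client B, with different data, is not.
client_a = client_class_logits([[4.0, 0.0], [2.0, 0.0]], [0, 0])
client_b = client_class_logits([[0.0, 2.0], [0.0, 6.0]], [0, 1])
global_logits = server_average([client_a, client_b])
```

In this toy case the averaged class-0 vector ends up between the two clients' views and is less decisive than client A's own estimate, which is the blurring effect noted above under heterogeneous data.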