关于联邦学习构造Non-IID数据集的记录
简单来说,Non-IID就是指每个设备中的数据分布不能代表全局数据分布。本篇简单记录一下自己在研究联邦学习过程中对Non-IID数据的思考和处理
Dateset Shift(数据集偏移)
联邦学习中客户端之间数据的Non-IID分布,和在做机器学习任务时可能遇到的训练集与测试集分布不一致是一个道理
训练集和测试集分布不一致被称作数据集偏移(Dataset Shift),有3种类型:
协变量偏移(Covariate Shift): 独立变量的偏移,指训练集和测试集的输入服从不同分布,但背后是服从同一个函数关系,即: \[ P_{train}(y|x) = P_{test}(y|x)\\ P_{train}(x) \neq P_{test}(x) \]
先验概率偏移(Prior Probability Shift): 目标变量的偏移,即: \[ P_{train}(x|y) = P_{test}(x|y)\\ P_{train}(y) \neq P_{test}(y) \]
概念偏移(Concept Shift): 独立变量和目标变量之间关系的偏移,即: \[ P_{train}(y|x) \neq P_{test}(y|x)\ and\ P_{train}(x) = P_{test}(x)\ in\ X\rightarrow Y\ problems \\ P_{train}(x|y) \neq P_{test}(x|y)\ and\ P_{train}(y) = P_{test}(y)\ in\ Y\rightarrow X\ problems \]
联邦学习中客户端数据Non-IID分布的五种类型
Feature distribution skew (convariate shift)
与数据集偏移中的协变量偏移同理
以MNIST数据集为例,不同的人写同一个数字,写法不同(即不同客户端的\(P_i(x)\)分布不同),但是不同客户端用这个特征
x
预测得到该标签y
的概率是相近的,即\(P(y|x)\)分布相同Label distribution skew (prior probability shift)
与数据集偏移中的先验概率偏移同理
以MNIST数据集为例,不同的客户端内,各个标签所占的比例不是平均的,如A有50%的标签
1
,B有50%的标签2
(即不同客户端的\(P_i(y)\)分布不同),但是当标签给定时,A和B中对应的特征都大概率相似,即\(P(x|y)\)分布相同Same label, different features (concept shift)
与数据集偏移中的概念偏移同理,\(P_{train}(x|y) \neq P_{test}(x|y)\ and\ P_{train}(y) = P_{test}(y)\ \)
通俗理解:同样是车(标签
y
),迈凯伦和花冠完全是两回事(特征x
)Same features, different label (concept shift)
与数据集偏移中的概念偏移同理,
\(P_{train}(y|x) \neq P_{test}(y|x)\ and\ P_{train}(x) = P_{test}(x)\)
通俗理解:同样是要买交通工具(特征
x
),A想买跑车,B想买机车(标签y
)Quantity skew or unbalancedness
即不同客户端持有的数据量差异较大
实践中遇到的具体实现方式
A. 第二类Non-IID的第一种分配方式
以MNIST为例,共10个客户端。分配后的数据表示方式:建立一个字典,对于每个客户i
,i
作为键,数据的索引存在一个列表中作为值。如:
1 | { |
数据分配思路:根据客户ID的奇偶性分为两类(当然也可以分为更多类,此处以两类为例)。同时将数据集根据标签分类,奇数、偶数各为一类。奇数数据在0~4客户端中平均分配,偶数数据在5~9客户端中平均分配。
当然,本例只是分为了2组,实际上也可以分为5组、10组。
代码及注释如下:
1 | def mnist_noniid(dataset, num_users): |
B. 第二类Non-IID的第二种分配方式
数据分配思路:将排序后的数据分为2 * num_users
个碎片,每个客户随机取2个碎片,那么该客户拥有的数据就只包含这2个碎片里的标签(比如说60000个数据分成200个碎片,每个碎片300个数据,假设某个客户取到了前面的600个数据,那么他取到的数据含有的标签就只可能包含0,1,2,但是不可能包含9),从而构成了Non-IID
代码及注释如下:
1 | def mnist_noniid(dataset, num_users): |
附:Python中[m, n]的用法
x[m, n]
是通过numpy引用数组或矩阵中的某一段数据集的一种写法。m
表示第几维,n
表示第几个数据。
在一个2维矩阵中,通俗来理解,x[m, n]
就是第m
行第n
列的数据,例如:
1 | x = np.array([ |
那么x[1, 2]
得到的数据就是6
其他常用的用法:
1 | x[n, :] # 取第n维的全部数据 |
参考
https://blog.csdn.net/qq_43827595/article/details/120661931