Per_FedAvg源码分析——根目录下:
本文分析 KarhouTam 的 Per_FedAvg 源码(链接见原仓库),如有使用请给原仓库点个 star。
函数:get_args()
功能:用于加载参数:使用ArgumentParser()输入了联邦参数 ,模型参数 ,其他参数 三类参数
import torch
import random
import numpy as np
from typing import Iterator, Tuple, Union
from argparse import ArgumentParser


def get_args():
    """Parse command-line arguments for the Per-FedAvg experiment.

    Groups: federated settings (epochs, client counts), meta-learning rates
    (alpha = inner step, beta = outer step), and misc flags (gpu, logging).
    """
    parser = ArgumentParser()
    parser.add_argument("--alpha", type=float, default=1e-2)
    parser.add_argument("--beta", type=float, default=1e-3)
    parser.add_argument("--global_epochs", type=int, default=200)
    parser.add_argument("--local_epochs", type=int, default=4)
    parser.add_argument(
        "--pers_epochs",
        type=int,
        default=1,
        help="Indicate how many data batches would be used for personalization. Negatives means that equal to train phase.",
    )
    parser.add_argument(
        "--hf",
        type=int,
        default=0,
        help="0 for performing Per-FedAvg(FO), others for Per-FedAvg(HF)",
    )
    parser.add_argument("--batch_size", type=int, default=40)
    parser.add_argument(
        "--valset_ratio",
        type=float,
        default=0.1,
        help="Proportion of val set in the entire client local dataset",
    )
    parser.add_argument(
        "--dataset", type=str, choices=["mnist", "cifar"], default="mnist"
    )
    parser.add_argument("--client_num_per_round", type=int, default=10)
    parser.add_argument("--seed", type=int, default=17)
    parser.add_argument(
        "--gpu",
        type=int,
        default=1,
        help="Non-zero value for using gpu, 0 for using cpu",
    )
    parser.add_argument(
        "--eval_while_training",
        type=int,
        default=1,
        help="Non-zero value for performing local evaluation before and after local training",
    )
    parser.add_argument("--log", type=int, default=1)
    return parser.parse_args()
函数:eval()
功能:用于在PyTorch中评估给定模型的性能。
@torch.no_grad()
def eval(
    model: torch.nn.Module,
    dataloader: torch.utils.data.DataLoader,
    criterion: Union[torch.nn.MSELoss, torch.nn.CrossEntropyLoss],
    device=torch.device("cpu"),
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Evaluate `model` over every batch of `dataloader`.

    Returns a tuple (loss summed over batches, accuracy in [0, 1]).
    NOTE: the name shadows the builtin ``eval``; kept for caller compatibility.
    """
    model.eval()
    total_loss = 0
    correct = 0
    num_samples = 0
    for x, y in dataloader:
        x = x.to(device)
        y = y.to(device)
        logit = model(x)
        total_loss += criterion(logit, y)
        pred = torch.softmax(logit, -1).argmax(-1)
        correct += torch.eq(pred, y).int().sum()
        num_samples += y.size(-1)
    # Restore training mode before handing the model back.
    model.train()
    return total_loss, correct / num_samples
函数:fix_random_seed(seed: int)
作用:设置随机种子以确保结果的可复现性
def fix_random_seed(seed: int):
    """Seed all RNG sources (torch CPU/CUDA, random, numpy) for reproducibility."""
    torch.cuda.empty_cache()
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # BUGFIX: was True. benchmark=True lets cuDNN auto-tune and pick
    # algorithms nondeterministically, defeating the purpose of this
    # function; PyTorch's reproducibility notes require it disabled.
    torch.backends.cudnn.benchmark = False
类:elu(nn.Module)
作用:它实现了指数线性单元(Exponential Linear Unit, ELU)激活函数。
class elu(nn.Module):
    """ELU-style activation: identity for x >= 0, 0.2*(exp(x)-1) otherwise.

    The negative-branch scale is fixed at 0.2 (standard ELU uses 1.0).
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        negative_branch = 0.2 * (torch.exp(x) - 1)
        return torch.where(x >= 0, x, negative_branch)
类:linear(nn.Module)
作用:__init__方法用于初始化权重(w)和偏置(b),而forward方法定义了数据通过网络层的前向传播过程。
class linear(nn.Module):
    """Fully-connected layer with He-style (sqrt(2/in_c)) weight init."""

    def __init__(self, in_c, out_c) -> None:
        super().__init__()
        scale = torch.sqrt(torch.tensor(2 / in_c))
        self.w = nn.Parameter(torch.randn(out_c, in_c) * scale)
        self.b = nn.Parameter(torch.randn(out_c))

    def forward(self, x):
        # y = x @ w.T + b
        return F.linear(x, self.w, self.b)
类:MLP_MNIST
作用:构建不同类型的神经网络模型,分别是MLP(多层感知机)、CNNMnist(用于MNIST手写数字数据集的卷积神经网络)和CNNCifar(用于CIFAR-10数据集的卷积神经网络)。实现了神经网络的前向传播过程,并用于分类任务。
class MLP_MNIST(nn.Module):
    """3-layer MLP for flattened 28x28 MNIST images, 10 output classes.

    NOTE(review): the activation is also applied after the final layer, so
    the returned values are ELU-squashed rather than raw logits — kept as-is.
    """

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = linear(28 * 28, 80)
        self.fc2 = linear(80, 60)
        self.fc3 = linear(60, 10)
        self.flatten = nn.Flatten()
        self.activation = elu()

    def forward(self, x):
        out = self.flatten(x)
        for fc in (self.fc1, self.fc2, self.fc3):
            out = self.activation(fc(out))
        return out


class MLP_CIFAR10(nn.Module):
    """3-layer MLP for flattened 32x32x3 CIFAR-10 images, 10 output classes."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = linear(32 * 32 * 3, 80)
        self.fc2 = linear(80, 60)
        self.fc3 = linear(60, 10)
        self.flatten = nn.Flatten()
        self.activation = elu()

    def forward(self, x):
        out = self.flatten(x)
        for fc in (self.fc1, self.fc2, self.fc3):
            out = self.activation(fc(out))
        return out
类:MLP_CIFAR10
作用:构建不同类型的神经网络模型,分别是MLP(多层感知机)、CNNMnist(用于MNIST手写数字数据集的卷积神经网络)和CNNCifar(用于CIFAR-10数据集的卷积神经网络)。实现了神经网络的前向传播过程,并用于分类任务。
class MLP_CIFAR10(nn.Module):
    """3-layer MLP for flattened 32x32x3 CIFAR-10 images, 10 output classes.

    NOTE(review): this definition duplicates the one shown earlier in the
    walkthrough; the activation after fc3 means the output is not raw logits.
    """

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = linear(32 * 32 * 3, 80)
        self.fc2 = linear(80, 60)
        self.fc3 = linear(60, 10)
        self.flatten = nn.Flatten()
        self.activation = elu()

    def forward(self, x):
        out = self.flatten(x)
        for fc in (self.fc1, self.fc2, self.fc3):
            out = self.activation(fc(out))
        return out
字典:MODEL_DICT
作用:关联键"mnist"和"cifar"到它们各自的多层感知机(MLP)模型类MLP_MNIST和MLP_CIFAR10
# Registry mapping dataset name to its model class; get_model() instantiates
# from this table.
MODEL_DICT = {"mnist": MLP_MNIST, "cifar": MLP_CIFAR10}
函数:get_model(dataset, device)
作用:根据数据集名称从MODEL_DICT字典中获取相应的模型类,并实例化模型,然后将模型移动到指定的设备上(CPU或GPU)
def get_model(dataset, device):
    """Instantiate the model registered for `dataset` and move it to `device`."""
    model_cls = MODEL_DICT[dataset]
    return model_cls().to(device)
函数:__init__()
作用:类的初始化方法,用于配置和初始化类的实例变量。该方法接收多个参数,包括客户端ID、学习率参数(alpha和beta,可能是某种优化算法中的参数,如Momentum或Adam中的beta1和beta2)、全局模型、损失函数、批量大小、数据集名称、本地训练轮数、验证集比例、日志记录器和GPU设备ID。
def __init__(
    self,
    client_id: int,
    alpha: float,
    beta: float,
    global_model: torch.nn.Module,
    criterion: Union[torch.nn.CrossEntropyLoss, torch.nn.MSELoss],
    batch_size: int,
    dataset: str,
    local_epochs: int,
    valset_ratio: float,
    logger: rich.console.Console,
    gpu: int,
):
    """Initialize a Per-FedAvg client.

    alpha and beta are the inner (per-step) and outer (meta) learning rates
    of Per-FedAvg — not Adam/momentum betas.
    """
    # Fall back to CPU when no GPU was requested or none is available.
    if gpu and torch.cuda.is_available():
        self.device = torch.device("cuda")
    else:
        self.device = torch.device("cpu")
    self.logger = logger
    self.local_epochs = local_epochs
    self.criterion = criterion
    self.id = client_id
    # Private copy of the global model; re-synced from the server each round.
    self.model = deepcopy(global_model)
    self.alpha = alpha
    self.beta = beta
    self.trainloader, self.valloader = get_dataloader(
        dataset, client_id, batch_size, valset_ratio
    )
    # Persistent iterator so successive get_data_batch() calls continue
    # where the previous one stopped.
    self.iter_trainloader = iter(self.trainloader)
函数:get_data_batch(self)
作用:用于从训练数据加载器中获取下一批数据,并处理StopIteration异常(当迭代器耗尽时触发)。当iter_trainloader中的数据被完全迭代一遍后,该方法会重新初始化迭代器并获取新的数据批次。
def get_data_batch(self):
    """Return the next (x, y) training batch on `self.device`.

    When the underlying loader is exhausted, transparently restart it and
    return the first batch of the new pass.
    """
    try:
        batch = next(self.iter_trainloader)
    except StopIteration:
        # Epoch boundary: rebuild the iterator and pull the first batch.
        self.iter_trainloader = iter(self.trainloader)
        batch = next(self.iter_trainloader)
    x, y = batch
    return x.to(self.device), y.to(self.device)
函数:train()
作用:是一个用于在本地客户端上训练模型的函数。该方法接收全局模型、一个布尔值hessian_free(用于指示是否使用Hessian-free优化)和一个布尔值eval_while_training(用于指示是否在训练前后评估模型性能)。
def train(
    self,
    global_model: torch.nn.Module,
    hessian_free=False,
    eval_while_training=False,
):
    """Run one round of local Per-FedAvg training from the current global model.

    Args:
        global_model: server model whose weights seed this round.
        hessian_free: True -> Per-FedAvg(HF) update, False -> Per-FedAvg(FO).
        eval_while_training: if truthy, evaluate on the local val set before
            and after training and log the change.

    Returns:
        The locally trained weights, serialized for server-side aggregation.
    """
    # Sync the local model with the latest global weights.
    self.model.load_state_dict(global_model.state_dict())
    if eval_while_training:
        loss_before, acc_before = utils.eval(
            self.model, self.valloader, self.criterion, self.device
        )
    self._train(hessian_free)
    if eval_while_training:
        loss_after, acc_after = utils.eval(
            self.model, self.valloader, self.criterion, self.device
        )
        self.logger.log(
            "client [{}] [red]loss: {:.4f} -> {:.4f} [blue]acc: {:.2f}% -> {:.2f}%".format(
                self.id,
                loss_before,
                loss_after,
                acc_before * 100.0,
                acc_after * 100.0,
            )
        )
    return SerializationTool.serialize_model(self.model)
函数:_train()
作用:执行本地元学习训练。hessian_free 为 True 时按 Per-FedAvg(HF) 用 Hessian-向量积近似进行二阶更新;否则按 Per-FedAvg(FO) 仅用一阶梯度近似更新。
def _train(self, hessian_free=False):
    """Per-FedAvg local update (MAML-style), HF or FO variant.

    Each local epoch:
      1) inner step: temp_model <- model - alpha * grad(batch_1)
      2) outer step on the ORIGINAL model, using gradients evaluated at
         temp_model (plus a Hessian correction in the HF variant).
    """
    if hessian_free:  # Per-FedAvg(HF): keeps a second-order correction term
        for _ in range(self.local_epochs):
            temp_model = deepcopy(self.model)
            data_batch_1 = self.get_data_batch()
            grads = self.compute_grad(temp_model, data_batch_1)
            # Inner (personalization) step with learning rate alpha.
            for param, grad in zip(temp_model.parameters(), grads):
                param.data.sub_(self.alpha * grad)
            data_batch_2 = self.get_data_batch()
            grads_1st = self.compute_grad(temp_model, data_batch_2)
            data_batch_3 = self.get_data_batch()
            # Approximate Hessian-vector product H @ grads_1st at self.model.
            grads_2nd = self.compute_grad(
                self.model, data_batch_3, v=grads_1st, second_order_grads=True
            )
            # Outer step: w <- w - beta * (I - alpha * H) @ grads_1st
            for param, grad1, grad2 in zip(
                self.model.parameters(), grads_1st, grads_2nd
            ):
                param.data.sub_(self.beta * grad1 - self.beta * self.alpha * grad2)
    else:  # Per-FedAvg(FO): first-order approximation, ignores the Hessian
        for _ in range(self.local_epochs):
            temp_model = deepcopy(self.model)
            data_batch_1 = self.get_data_batch()
            grads = self.compute_grad(temp_model, data_batch_1)
            # Inner (personalization) step with learning rate alpha.
            for param, grad in zip(temp_model.parameters(), grads):
                param.data.sub_(self.alpha * grad)
            data_batch_2 = self.get_data_batch()
            grads = self.compute_grad(temp_model, data_batch_2)
            # Outer step applies the post-inner-step gradient to the original model.
            for param, grad in zip(self.model.parameters(), grads):
                param.data.sub_(self.beta * grad)
函数:compute_grad()
作用:根据给定的数据批次data_batch计算模型model的梯度。如果second_order_grads为True,它将计算二阶梯度(Hessian-vector积的一个近似),否则,它将计算标准的一阶梯度。
def compute_grad(
    self,
    model: torch.nn.Module,
    data_batch: Tuple[torch.Tensor, torch.Tensor],
    v: Union[Tuple[torch.Tensor, ...], None] = None,
    second_order_grads=False,
):
    """Gradients of the loss on `data_batch` w.r.t. `model`'s parameters.

    With second_order_grads=True, approximates the Hessian-vector product
    H @ v via a central finite difference:
        (grad(w + delta*v) - grad(w - delta*v)) / (2 * delta)
    The model's weights are restored to their original values afterwards.
    """
    x, y = data_batch
    if second_order_grads:
        # Snapshot the untouched weights so the model can be restored.
        frz_model_params = deepcopy(model.state_dict())
        delta = 1e-3  # finite-difference step size
        dummy_model_params_1 = OrderedDict()
        dummy_model_params_2 = OrderedDict()
        with torch.no_grad():
            # Perturb every parameter by +/- delta along direction v.
            for (layer_name, param), grad in zip(model.named_parameters(), v):
                dummy_model_params_1.update({layer_name: param + delta * grad})
                dummy_model_params_2.update({layer_name: param - delta * grad})
        # Gradient at w + delta*v.
        model.load_state_dict(dummy_model_params_1, strict=False)
        logit_1 = model(x)
        loss_1 = self.criterion(logit_1, y)
        grads_1 = torch.autograd.grad(loss_1, model.parameters())
        # Gradient at w - delta*v.
        model.load_state_dict(dummy_model_params_2, strict=False)
        logit_2 = model(x)
        loss_2 = self.criterion(logit_2, y)
        grads_2 = torch.autograd.grad(loss_2, model.parameters())
        # Restore the original weights.
        model.load_state_dict(frz_model_params)
        grads = []
        with torch.no_grad():
            for g1, g2 in zip(grads_1, grads_2):
                grads.append((g1 - g2) / (2 * delta))
        return grads
    else:
        # Plain first-order gradient.
        logit = model(x)
        loss = self.criterion(logit, y)
        grads = torch.autograd.grad(loss, model.parameters())
        return grads
函数:pers_N_eval()
作用:在给定全局模型(global_model)和个性化训练轮次(pers_epochs)之后,该函数首先加载全局模型的参数到客户端的本地模型(self.model),然后在本地数据集上进行训练和评估。
def pers_N_eval(self, global_model: torch.nn.Module, pers_epochs: int):
    """Personalize the global model with `pers_epochs` SGD steps, then evaluate.

    Loads the global weights, measures val loss/acc, runs a few SGD steps
    (one batch per step, lr = alpha), measures again, logs, and returns the
    before/after metrics as a dict.
    """
    self.model.load_state_dict(global_model.state_dict())
    loss_before, acc_before = utils.eval(
        self.model, self.valloader, self.criterion, self.device
    )
    # Personalization uses the inner learning rate alpha, one batch per epoch.
    optimizer = torch.optim.SGD(self.model.parameters(), lr=self.alpha)
    for _ in range(pers_epochs):
        x, y = self.get_data_batch()
        logit = self.model(x)
        loss = self.criterion(logit, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    loss_after, acc_after = utils.eval(
        self.model, self.valloader, self.criterion, self.device
    )
    self.logger.log(
        "client [{}] [red]loss: {:.4f} -> {:.4f} [blue]acc: {:.2f}% -> {:.2f}%".format(
            self.id, loss_before, loss_after, acc_before * 100.0, acc_after * 100.0
        )
    )
    return {
        "loss_before": loss_before,
        "acc_before": acc_before,
        "loss_after": loss_after,
        "acc_after": acc_after,
    }
用于启动分布式或联邦学习中的客户端或服务器进程。
if __name__ == "__main__":
    # ---- setup: args, seeding, log dir, device, global model ----
    args = get_args()
    fix_random_seed(args.seed)
    if os.path.isdir("./log") == False:
        os.mkdir("./log")
    if args.gpu and torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    global_model = get_model(args.dataset, device)
    logger = Console(record=args.log)
    logger.log(f"Arguments:", dict(args._get_kwargs()))
    # Client id partition: which clients train vs. which only evaluate.
    clients_4_training, clients_4_eval, client_num_in_total = get_client_id_indices(
        args.dataset
    )
    # One client object per id; each loads its own data shard.
    clients = [
        PerFedAvgClient(
            client_id=client_id,
            alpha=args.alpha,
            beta=args.beta,
            global_model=global_model,
            criterion=torch.nn.CrossEntropyLoss(),
            batch_size=args.batch_size,
            dataset=args.dataset,
            local_epochs=args.local_epochs,
            valset_ratio=args.valset_ratio,
            logger=logger,
            gpu=args.gpu,
        )
        for client_id in range(client_num_in_total)
    ]
    # ---- federated training loop ----
    logger.log("=" * 20, "TRAINING", "=" * 20, style="bold red")
    # NOTE(review): progress bar is disabled when logging is ON — confirm
    # this inversion is intentional (rich progress + record mode conflict?).
    for _ in track(
        range(args.global_epochs), "Training...", console=logger, disable=args.log
    ):
        selected_clients = random.sample(clients_4_training, args.client_num_per_round)
        model_params_cache = []
        for client_id in selected_clients:
            serialized_model_params = clients[client_id].train(
                global_model=global_model,
                hessian_free=args.hf,
                eval_while_training=args.eval_while_training,
            )
            model_params_cache.append(serialized_model_params)
        # FedAvg aggregation, written back into the global model.
        aggregated_model_params = Aggregators.fedavg_aggregate(model_params_cache)
        SerializationTool.deserialize_model(global_model, aggregated_model_params)
        logger.log("=" * 60)
    # ---- personalization + evaluation on held-out clients ----
    # pers_epochs == -1 means "same as local training epochs".
    pers_epochs = args.local_epochs if args.pers_epochs == -1 else args.pers_epochs
    logger.log("=" * 20, "EVALUATION", "=" * 20, style="bold blue")
    loss_before = []
    loss_after = []
    acc_before = []
    acc_after = []
    for client_id in track(
        clients_4_eval, "Evaluating...", console=logger, disable=args.log
    ):
        stats = clients[client_id].pers_N_eval(
            global_model=global_model,
            pers_epochs=pers_epochs,
        )
        loss_before.append(stats["loss_before"])
        loss_after.append(stats["loss_after"])
        acc_before.append(stats["acc_before"])
        acc_after.append(stats["acc_after"])
    # ---- averaged results across eval clients ----
    logger.log("=" * 20, "RESULTS", "=" * 20, style="bold green")
    logger.log(f"loss_before_pers: {(sum(loss_before) / len(loss_before)):.4f}")
    logger.log(f"acc_before_pers: {(sum(acc_before) * 100.0 / len(acc_before)):.2f}%")
    logger.log(f"loss_after_pers: {(sum(loss_after) / len(loss_after)):.4f}")
    logger.log(f"acc_after_pers: {(sum(acc_after) * 100.0 / len(acc_after)):.2f}%")
    if args.log:
        algo = "HF" if args.hf else "FO"
        logger.save_html(
            f"./log/{args.dataset}_{args.client_num_per_round}_{args.global_epochs}_{pers_epochs}_{algo}.html"
        )
Per_FedAvg源码分析——data目录下:
不做分析
字典:DATASET_DICT
作用:它将字符串键(如 "mnist" 和 "cifar")映射到对应的类(MNISTDataset 和 CIFARDataset)
# Registry mapping dataset name to its Dataset wrapper class; get_dataloader()
# unpickles per-client shards of these types.
DATASET_DICT = {
    "mnist": MNISTDataset,
    "cifar": CIFARDataset,
}
变量:CURRENT_DIR
作用:CURRENT_DIR 被设置为当前 Python 脚本文件的父目录的绝对路径
# Absolute path of the directory containing this file.
# NOTE(review): `.abspath()` is the path.py library's Path API, not
# pathlib.Path — confirm which `Path` is imported; with pathlib this raises.
CURRENT_DIR = Path(__file__).parent.abspath()
函数:get_dataloader
作用:从一个预处理好的 pickle 文件中加载数据集,并根据给定的 client_id 分割为训练集和验证集。
def get_dataloader(dataset: str, client_id: int, batch_size=20, valset_ratio=0.1):
    """Load client `client_id`'s pickled shard of `dataset` and split it.

    Returns (trainloader, valloader); the train loader drops the last
    incomplete batch, the val loader does not. Raises RuntimeError if the
    pickles have not been generated by the preprocessing script.
    """
    pickles_dir = CURRENT_DIR / dataset / "pickles"
    if os.path.isdir(pickles_dir) is False:
        raise RuntimeError("Please preprocess and create pickles first.")
    # BUGFIX: was `pickles_dir / str(client_id) + ".pkl"`, which parses as
    # (pickles_dir / str(client_id)) + ".pkl" and raises TypeError under
    # pathlib.Path; build the filename explicitly so either path library works.
    # (Also dropped the bogus `client_dataset: DATASET_DICT[dataset]`
    # annotation — a runtime value, not a type.)
    with open(pickles_dir / f"{client_id}.pkl", "rb") as f:
        client_dataset = pickle.load(f)
    val_num_samples = int(valset_ratio * len(client_dataset))
    train_num_samples = len(client_dataset) - val_num_samples
    trainset, valset = random_split(
        client_dataset, [train_num_samples, val_num_samples]
    )
    trainloader = DataLoader(trainset, batch_size, drop_last=True)
    valloader = DataLoader(valset, batch_size)
    return trainloader, valloader
函数:get_client_id_indices(dataset)
作用:从一个特定的 pickle 文件中加载并返回关于数据集分割的信息。从一个 seperation.pkl 文件中读取训练集、测试集以及总数目的索引或标识符。
def get_client_id_indices(dataset):
    """Read the client split written by preprocessing.

    Returns (train client ids, test client ids, total client count) from
    `seperation.pkl` (filename kept as spelled by the preprocessing script).
    """
    pickles_path = CURRENT_DIR / dataset / "pickles"
    with open(pickles_path / "seperation.pkl", "rb") as f:
        seperation = pickle.load(f)
    return seperation["train"], seperation["test"], seperation["total"]
变量:CURRENT_DIR
作用:CURRENT_DIR 被设置为当前 Python 脚本文件的父目录的绝对路径
# Absolute path of the directory containing this file.
# NOTE(review): `.abspath()` is the path.py library's Path API, not
# pathlib.Path — confirm which `Path` is imported; with pathlib this raises.
CURRENT_DIR = Path(__file__).parent.abspath()
字典:DATASET
作用:数据集名称映射到了两个元组,可以基于数据集名称来动态地加载和实例化相应的数据集
# Maps dataset name to (torchvision dataset class, wrapper Dataset class).
# preprocess() uses the first to download raw data and the second to wrap
# each client's shard.
DATASET = {
    "mnist": (MNIST, MNISTDataset),
    "cifar": (CIFAR10, CIFARDataset),
}
字典:MEAN
作用:用来存储不同数据集的像素均值。这些均值通常用于数据归一化,只包含一个灰度通道,因此其均值是一个单元素元组 (0.1307,)。这意味着当你对 MNIST 数据集进行归一化时,你会从每个像素值中减去 0.1307。
# Per-channel means fed to transforms.Normalize in preprocess().
# MNIST is single-channel; CIFAR-10 has three (RGB) channels.
MEAN = {
    "mnist": (0.1307,),
    "cifar": (0.4914, 0.4822, 0.4465),
}
字典:STD
作用:存储不同数据集的像素标准差。使用归一化时,标准差通常与均值一起使用,以确保数据的每个特征(在这个案例中是像素值)都有相似的尺度。它只包含一个灰度通道,因此其标准差是一个单元素元组 (0.3015,)。这意味着在归一化 MNIST 数据时,每个像素值都会根据其灰度通道的标准差进行缩放。
# Per-channel standard deviations fed to transforms.Normalize in preprocess().
# NOTE(review): 0.3015 for MNIST differs from the commonly quoted 0.3081 —
# confirm it was computed for this pipeline.
STD = {
    "mnist": (0.3015,),
    "cifar": (0.2023, 0.1994, 0.2010),
}
函数:preprocess()
作用:用于预处理数据集,在联邦学习或分布式学习的场景中,数据需要在多个客户端(或节点)之间分配。
def preprocess(args: Namespace) -> None:
    """Download a torchvision dataset and partition it into per-client pickles.

    Seeds all RNGs, splits clients into train/test groups by `args.fraction`,
    allocates non-IID class shards to each client, then writes one pickle per
    client plus `seperation.pkl` (the client split) and `all_stats.json`.
    """
    dataset_dir = CURRENT_DIR / args.dataset
    pickles_dir = CURRENT_DIR / args.dataset / "pickles"
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    num_train_clients = int(args.client_num_in_total * args.fraction)
    num_test_clients = args.client_num_in_total - num_train_clients

    # Raw tensors are normalized lazily via the wrapper dataset's transform.
    transform = transforms.Compose(
        [transforms.Normalize(MEAN[args.dataset], STD[args.dataset])]
    )
    target_transform = None
    trainset_stats = {}
    testset_stats = {}

    if not os.path.isdir(CURRENT_DIR / args.dataset):
        os.mkdir(CURRENT_DIR / args.dataset)
    # Rebuild the pickles directory from scratch on every run.
    if os.path.isdir(pickles_dir):
        os.system(f"rm -rf {pickles_dir}")
    os.mkdir(f"{pickles_dir}")

    ori_dataset, target_dataset = DATASET[args.dataset]
    trainset = ori_dataset(
        dataset_dir, train=True, download=True, transform=transforms.ToTensor()
    )
    testset = ori_dataset(dataset_dir, train=False, transform=transforms.ToTensor())
    # classes <= 0 means "use all 10 classes per client".
    num_classes = 10 if args.classes <= 0 else args.classes
    all_trainsets, trainset_stats = randomly_alloc_classes(
        ori_dataset=trainset,
        target_dataset=target_dataset,
        num_clients=num_train_clients,
        num_classes=num_classes,
        transform=transform,
        target_transform=target_transform,
    )
    all_testsets, testset_stats = randomly_alloc_classes(
        ori_dataset=testset,
        target_dataset=target_dataset,
        num_clients=num_test_clients,
        num_classes=num_classes,
        transform=transform,
        target_transform=target_transform,
    )
    all_datasets = all_trainsets + all_testsets

    for client_id, dataset in enumerate(all_datasets):
        # BUGFIX: was `pickles_dir / str(client_id) + ".pkl"`, which parses as
        # (pickles_dir / str(client_id)) + ".pkl" and raises TypeError under
        # pathlib.Path; build the filename explicitly instead.
        with open(pickles_dir / f"{client_id}.pkl", "wb") as f:
            pickle.dump(dataset, f)
    with open(pickles_dir / "seperation.pkl", "wb") as f:
        pickle.dump(
            {
                "train": [i for i in range(num_train_clients)],
                "test": [
                    i for i in range(num_train_clients, args.client_num_in_total)
                ],
                "total": args.client_num_in_total,
            },
            f,
        )
    with open(dataset_dir / "all_stats.json", "w") as f:
        json.dump({"train": trainset_stats, "test": testset_stats}, f)
函数:randomly_alloc_classes
作用:将原始数据集(ori_dataset)中的样本随机分配给多个客户端(或用户),同时确保每个客户端获得指定数量的不同类别的样本。函数还返回了分配给每个客户端的数据集列表和相应的统计信息。
def randomly_alloc_classes(
    ori_dataset: Dataset,
    target_dataset: Dataset,
    num_clients: int,
    num_classes: int,
    transform=None,
    target_transform=None,
) -> Tuple[List[Dataset], Dict[str, Dict[str, int]]]:
    """Slice `ori_dataset` non-IID across `num_clients` clients.

    Each client receives roughly `num_classes` class shards (via fedlab's
    noniid_slicing). Returns the list of wrapped per-client datasets and a
    stats dict with each client's sample count ("x") and label histogram ("y").
    """
    dict_users = noniid_slicing(ori_dataset, num_clients, num_clients * num_classes)
    # PERF: hoist the label array out of the loop — the original rebuilt
    # np.array(ori_dataset.targets) once per client.
    targets_numpy = np.array(ori_dataset.targets)
    stats = {}
    for i, indices in dict_users.items():
        stats[f"client {i}"] = {
            "x": len(indices),
            "y": Counter(targets_numpy[indices].tolist()),
        }
    datasets = []
    for indices in dict_users.values():
        datasets.append(
            target_dataset(
                [ori_dataset[i] for i in indices],
                transform=transform,
                target_transform=target_transform,
            )
        )
    return datasets, stats
入口:if __name__ == "__main__"
作用:基本的命令行参数解析设置,它使用argparse库来从命令行获取参数。这些参数包括数据集类型、客户端总数、训练客户端的比例、每个客户端数据所属的类别数量以及随机种子。
if __name__ == "__main__":
    # CLI for the preprocessing step: dataset choice, client count,
    # train-client fraction, classes per client, and RNG seed.
    arg_parser = ArgumentParser()
    arg_parser.add_argument(
        "--dataset",
        type=str,
        choices=["mnist", "cifar"],
        default="mnist",
    )
    arg_parser.add_argument("--client_num_in_total", type=int, default=200)
    arg_parser.add_argument(
        "--fraction", type=float, default=0.9, help="Propotion of train clients"
    )
    arg_parser.add_argument(
        "--classes",
        type=int,
        default=2,
        help="Num of classes that one client's data belong to.",
    )
    arg_parser.add_argument("--seed", type=int, default=0)
    preprocess(arg_parser.parse_args())
类:MNISTDataset(Dataset)
函数:init
作用:用于初始化一个对象的状态
def __init__(
    self,
    subset=None,
    data=None,
    targets=None,
    transform=None,
    target_transform=None,
) -> None:
    """Build the dataset from (data, targets) tensors or from a `subset`
    of (sample, label) pairs; exactly one of the two forms must be given."""
    self.transform = transform
    self.target_transform = target_transform
    if (data is not None) and (targets is not None):
        # Insert a channel dimension: (N, H, W) -> (N, 1, H, W).
        self.data = data.unsqueeze(1)
        self.targets = targets
    elif subset is not None:
        # Coerce every sample/label to a tensor before stacking.
        as_tensor = lambda v: v if isinstance(v, torch.Tensor) else torch.tensor(v)
        self.data = torch.stack([as_tensor(pair[0]) for pair in subset])
        self.targets = torch.stack([as_tensor(pair[1]) for pair in subset])
    else:
        raise ValueError(
            "Data Format: subset: Tuple(data: Tensor / Image / np.ndarray, targets: Tensor) OR data: List[Tensor] targets: List[Tensor]"
        )
函数:getitem
作用:允许类的实例像列表、元组或其他可迭代对象那样进行索引访问。在你提供的上下文中,这个方法通常用于数据加载器(如PyTorch的DataLoader),以便在训练或评估模型时能够按索引访问数据集中的单个样本。
def __getitem__(self, index):
    """Return the (optionally transformed) sample/label pair at `index`."""
    data = self.data[index]
    targets = self.targets[index]
    if self.transform is not None:
        data = self.transform(data)
    if self.target_transform is not None:
        targets = self.target_transform(targets)
    return data, targets
函数:len
作用:确定self.data的长度
def __len__(self):
    """Number of samples (length of the targets tensor)."""
    return len(self.targets)
类:CIFARDataset(Dataset)
函数:init
作用:用于初始化一个对象的状态
def __init__(
    self,
    subset=None,
    data=None,
    targets=None,
    transform=None,
    target_transform=None,
) -> None:
    """Build the dataset from (data, targets) tensors or from a `subset`
    of (sample, label) pairs; exactly one of the two forms must be given."""
    self.transform = transform
    self.target_transform = target_transform
    if (data is not None) and (targets is not None):
        # NOTE(review): unsqueeze(1) assumes (N, H, W) input; for RGB CIFAR
        # arrays this may misplace the channel dim — confirm the caller.
        self.data = data.unsqueeze(1)
        self.targets = targets
    elif subset is not None:
        # Coerce every sample/label to a tensor before stacking.
        as_tensor = lambda v: v if isinstance(v, torch.Tensor) else torch.tensor(v)
        self.data = torch.stack([as_tensor(pair[0]) for pair in subset])
        self.targets = torch.stack([as_tensor(pair[1]) for pair in subset])
    else:
        raise ValueError(
            "Data Format: subset: Tuple(data: Tensor / Image / np.ndarray, targets: Tensor) OR data: List[Tensor] targets: List[Tensor]"
        )
函数:getitem
作用:允许类的实例像列表、元组或其他可迭代对象那样进行索引访问。在你提供的上下文中,这个方法通常用于数据加载器(如PyTorch的DataLoader),以便在训练或评估模型时能够按索引访问数据集中的单个样本。
def __getitem__(self, index):
    """Return the (optionally transformed) image/label pair at `index`."""
    img = self.data[index]
    targets = self.targets[index]
    if self.transform is not None:
        img = self.transform(img)
    if self.target_transform is not None:
        targets = self.target_transform(targets)
    return img, targets
函数:len
作用:返回长度
def __len__(self):
    """Number of samples (length of the targets tensor)."""
    return len(self.targets)