Federated Learning

In general, supervised machine learning works like this:

  1. Define a relationship between an input domain (images of birds) and an output domain (bird species).
  2. Pick a model that could fit the underlying pattern.

$$F_\Theta: \mathbf{x} \rightarrow \mathbf{y}$$

  3. Use data to find the right parameters of the model ($\Theta$).
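The three steps above can be sketched end to end on a toy problem. Everything here is illustrative: a hypothetical 1-D model family $F_\Theta(x) = \Theta x$ and synthetic data generated from $y = 2x$, fit with plain gradient descent.

```python
import random

# Steps 1 & 2: input/output domains and a model family F_theta.
# Toy setup (illustrative, not from the original post): y = 2x.
def model(theta, x):
    return theta * x

# Step 3: use data to find the parameters, here by gradient
# descent on the squared error (theta * x - y)^2.
data = [(x, 2.0 * x) for x in range(1, 11)]
theta = random.uniform(-1.0, 1.0)
lr = 0.001
for _ in range(1000):
    for x, y in data:
        grad = 2 * (model(theta, x) - y) * x  # d/dtheta of the squared error
        theta -= lr * grad

print(round(theta, 3))  # close to 2.0
```

Later, the "data" list is exactly the thing a client in a federation would keep private.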

So we need data, and this isn't easy or cheap to get. Also: what if the data is considered private by whoever owns or generates it? Federated learning enables machine learning with private data without breaching privacy.

How does it work?

The term Federated Learning was coined by H. Brendan McMahan et al. [1] In their paper, the idea is to think of a federation as $K$ clients that each own a horizontal split -- a shard, split by examples -- of one training set, $D_\text{train}$.

$$D_\text{train} = D_1 \cup D_2 \cup \dots \cup D_K$$

Each of these clients trains on its local data as usual. After a while -- a number of batches or epochs -- each client sends the updated parameters of its model to a parameter server. This server aggregates all the model updates and sends the new global model back to all the clients.

```mermaid
graph BT
  k1[Client 1]---ps[Parameter Server]
  k2[Client 2]---ps
  k3[Client ...]---ps
  k4[Client K]---ps
```

The idea is that the clients never share their data. The only things going back and forth are the model parameters.

```mermaid
graph TB
  ps0[Server Initializes Model]--Send parameters-->c[Clients]
  c--Train on local data-->mu[Client Model Update]
  mu--Send update-->ps[Parameter Server]
  ps--Aggregate updates-->gm[New Global Model]
  gm--Send parameters-->c
```

In the original form proposed by McMahan et al. [1], the aggregation method is an average of the parameters of all the individual client model updates, weighted by the size of each client's local dataset relative to the total size of all the data put together.

$$\Theta_{t+1} = \sum_{k \in S_t} \frac{n_k}{n} \Theta^k_{t+1}$$

Here $S_t$ is the set of clients participating in round $t$, $n_k$ is the number of examples on client $k$, and $n = \sum_{k \in S_t} n_k$.
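The weighted average above can be sketched in a few lines. The client updates here are hypothetical flat lists of floats standing in for model parameters; `client_sizes` plays the role of $n_k$.

```python
# FedAvg aggregation step: average client parameters weighted by n_k / n.
def fedavg(client_params, client_sizes):
    n = sum(client_sizes)
    dim = len(client_params[0])
    global_params = [0.0] * dim
    for params, n_k in zip(client_params, client_sizes):
        for i, p in enumerate(params):
            global_params[i] += (n_k / n) * p
    return global_params

# Three hypothetical clients with different amounts of local data:
updates = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
sizes = [10, 30, 60]
print(fedavg(updates, sizes))  # weights 0.1, 0.3, 0.6 -> approximately [4.0, 5.0]
```

The client with the most data (60 of 100 examples) pulls the average hardest, which is exactly the behavior the weighting term encodes.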

What are the open problems?

Federated Learning is a relatively simple concept, but there are plenty of interesting open problems. Here are three highlights.

Reconstruction Attacks

It turns out privacy is not guaranteed just by sharing the parameters instead of the data. Jonas Geiping et al. [2] showed that high-resolution images can be recovered from model updates for realistic federated learning architectures. A solution might be to apply differential privacy [3] to the parameters before sending them to the parameter server.
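As a rough sketch of that mitigation, a client could clip its update and add Gaussian noise before sending it (the Gaussian mechanism). The `clip` and `sigma` values below are illustrative placeholders, not a calibrated $(\varepsilon, \delta)$ guarantee.

```python
import random

def privatize(params, clip=1.0, sigma=0.1):
    # Clip the update's L2 norm to `clip`, then add independent
    # Gaussian noise to every coordinate. Both values are illustrative.
    norm = sum(p * p for p in params) ** 0.5
    scale = min(1.0, clip / norm) if norm > 0 else 1.0
    return [p * scale + random.gauss(0.0, sigma) for p in params]

update = [0.5, -1.2, 3.0]
noisy = privatize(update)
print(noisy)  # same shape as `update`, but clipped and perturbed
```

Clipping bounds any single client's influence, and the noise masks what remains -- the combination is what a production DP-FL system would calibrate carefully.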

Full Decentralization

A logical next step beyond FL is to fully decentralize it. If you do this in an all-reduce fashion -- every client sends its model update to every other client, and every client acts as an aggregator -- communication costs scale quadratically with the number of clients. Even for a small model weighing only ~10MB, each client in a small federation of 200 would send and receive ~4GB of traffic per round of training. One solution could be a gossip communication protocol, where each client talks to a fixed number of other clients and an update gradually propagates through the network in this way [4].
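A toy simulation can show why gossip works: if each of $K$ clients holds one scalar "parameter" and repeatedly averages with a random peer, all values converge to the global mean with no central server. The schedule below (one random peer per client per round) is a simplification of the cited protocol.

```python
import random

random.seed(0)
K = 200
# Each client starts with a different scalar "model parameter".
params = [random.uniform(0.0, 10.0) for _ in range(K)]
target = sum(params) / K  # pairwise averaging preserves the mean

for _ in range(100):               # gossip rounds
    for i in range(K):
        j = random.randrange(K)    # talk to one random peer
        avg = (params[i] + params[j]) / 2
        params[i] = params[j] = avg

spread = max(params) - min(params)
print(round(target, 3), round(spread, 6))  # spread shrinks toward 0
```

Each client only ever touches a couple of peers per round, so per-client traffic stays constant while consensus still emerges network-wide.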

Model Poisoning

If FL is deployed in a setting where the clients can't be trusted, or where the data pipeline feeding those clients is vulnerable, model poisoning becomes possible: intentionally degrading the model's performance or manipulating its output. This is a new problem that FL introduces. There is a cat-and-mouse game going on, with researchers designing improvements to the aggregation algorithm to defend against model poisoning [5], and others designing new attacks that try to beat them [6].
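As a minimal sketch of the defensive side, here is a coordinate-wise median aggregator -- a simple robust rule in the spirit of, but not identical to, the Byzantine-tolerant methods cited above. Unlike a plain mean, a single wildly poisoned update cannot drag the aggregate arbitrarily far.

```python
def median(values):
    # Middle element for odd counts, midpoint of the two middle ones otherwise.
    s = sorted(values)
    mid = len(s) // 2
    return s[mid] if len(s) % 2 else (s[mid - 1] + s[mid]) / 2

def robust_aggregate(client_params):
    # Take the median of each parameter coordinate across all clients.
    dim = len(client_params[0])
    return [median([p[i] for p in client_params]) for i in range(dim)]

honest = [[1.0], [1.1], [0.9], [1.05]]
poisoned = honest + [[1000.0]]          # one hypothetical malicious client
print(robust_aggregate(poisoned))       # stays near 1.0, unlike the mean
```

The trade-off is typical of this cat-and-mouse game: robust rules blunt crude attacks but give up some statistical efficiency, and stealthier attacks are designed specifically to slip past them.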

There's a great paper by Peter Kairouz and (again) H. Brendan McMahan et al. [7] that lists and explains many more advances and open problems in federated learning, if you want to learn more.

You could also send me a DM on Twitter if you want to have a chat. 👋



1. McMahan et al. (2017): Communication-Efficient Learning of Deep Networks from Decentralized Data
2. Geiping et al. (2020): Inverting Gradients -- How Easy Is It to Break Privacy in Federated Learning?
3. Wikipedia: Differential Privacy
4. Daily et al. (2018): GossipGraD: Scalable Deep Learning Using Gossip Communication Based Asynchronous Gradient Descent
5. Blanchard et al. (2017): Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent
6. Wang et al. (2020): Attack of the Tails: Yes, You Really Can Backdoor Federated Learning
7. Kairouz et al. (2019): Advances and Open Problems in Federated Learning
8. Bhagoji et al. (2019): Analyzing Federated Learning Through an Adversarial Lens