May is the trending mobile app for young (and soon-to-be) parents. The app provides direct access to a doctor for any medical question about pregnancy or about their newborn child.
Efficiently allocating the large volume of parent requests to a restricted pool of highly skilled paediatricians is a huge challenge that May solves for its customers. Any improvement in the way paediatricians access information about patients, and in the way they respond, can improve waiting times and the overall quality of service on the May app.
May and Sarus joined forces to build a specialized LLM, using fine-tuning, while protecting patients’ data.
How May and Sarus leveraged LLMs to help doctors give better responses without taking any risk with their customers' private information
The Problem we Solved
Through the May app, patients can easily contact doctors and access a wealth of documents covering various topics related to pregnancy and newborn care.
Some questions are very specific and need a custom response; others are typical parent questions that can be answered by pointing to the right document and relating it to the question.
Based on the app's history of conversations between doctors and patients, May wanted to suggest elements of a response for new questions. Doing so would improve the quality of responses, especially when the response consists of pointing to the right resources. However, customer data is highly sensitive, and May required strong guarantees against privacy risks.
The solution: privacy-preserving LLM fine-tuning
Recent developments in generative AI gave rise to a whole family of AI solutions based on few-shot learning. The most popular of these is Retrieval Augmented Generation (RAG). RAG consists of processing a request (e.g. a question to a doctor) by:
- retrieving the few related documents that can help with the response (the documents have previously gone through semantic indexing)
- combining the request and the context documents into an enriched request
- sending the enriched request to an AI model trained for instruction following
- getting the response back
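As a minimal illustration of this pipeline, here is a sketch using sentence-transformers for the semantic indexing step; the document list is a placeholder and `generate` is a stub standing in for any instruction-following LLM, not May's actual stack.

```python
# Minimal RAG sketch: semantic retrieval + enriched request.
# Documents and the `generate` stub are placeholders, not May's actual stack.
import numpy as np
from sentence_transformers import SentenceTransformer

documents = [
    "Guide: fever in newborns and when to consult.",
    "Guide: breastfeeding positions and common issues.",
    "Guide: sleep during the first trimester of pregnancy.",
]

encoder = SentenceTransformer("all-MiniLM-L6-v2")
doc_embeddings = encoder.encode(documents, normalize_embeddings=True)  # semantic index

def generate(prompt: str) -> str:
    """Placeholder for a call to an instruction-following LLM."""
    raise NotImplementedError("plug in the LLM of your choice here")

def answer(question: str, k: int = 2) -> str:
    # 1. Retrieve the k most related documents (cosine similarity).
    q_emb = encoder.encode([question], normalize_embeddings=True)[0]
    top = np.argsort(-(doc_embeddings @ q_emb))[:k]
    context = "\n".join(documents[i] for i in top)
    # 2. Put the request and context documents together into an enriched request.
    enriched = (
        "Answer the parent's question using the context below.\n"
        f"Context:\n{context}\n\nQuestion: {question}"
    )
    # 3. Send the enriched request to the model and return its response.
    return generate(enriched)
```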
RAG has many interesting properties: it does not require any model training, it makes it simple to integrate new information into the system, and it is less prone to hallucinations than plain prompt engineering (zero-shot learning).
But when you have a large amount of training data, user feedback on the quality of doctor responses, knowledge spread across thousands of conversations, and logical links that are not explicitly stated, RAG is no longer the right solution; fine-tuning is the path to consider.
Fine-tuning is simply the process of continuing the training of the AI model with new data (in our case, conversations between patients and doctors, possibly with quality scores).
So let’s fine-tune our model… end of the story?
Not so fast! Large models have the bad habit of memorizing training data (see Carlini et al. 2019, Tirumala et al. 2022 and Kaddour et al. 2023). To prevent a fine-tuned model from leaking private training data, Sarus developed a Privacy Preserving LLM Fine-tuning stack leveraging the latest techniques to do so efficiently:
- Efficient multi-GPU training (Rasley et al. 2019)
- Low-Rank Adaptation (LoRA) (Hu et al. 2021)
- Quantization (Dettmers et al. 2023)
- DP-SGD (Abadi et al. 2016 and Ponomareva et al. 2023)
- Book Keeping (BK) algorithm for per-example gradient clipping (Bu et al. 2023)
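To make the core mechanism concrete, here is a minimal, naive sketch of a DP-SGD step (per-example gradient clipping followed by Gaussian noise). It is an illustration, not Sarus's actual stack; in particular, the Book Keeping algorithm obtains the same clipped gradients far more efficiently than the per-example loop below.

```python
# Naive DP-SGD step (Abadi et al. 2016): clip each per-example gradient,
# sum, add Gaussian noise, then average. Illustration only: the Book Keeping
# algorithm (Bu et al. 2023) computes the same clipped gradients far more cheaply.
import torch

def dp_sgd_step(model, loss_fn, batch_x, batch_y, optimizer,
                max_grad_norm=1.0, noise_multiplier=1.0):
    params = [p for p in model.parameters() if p.requires_grad]
    summed = [torch.zeros_like(p) for p in params]

    # 1. Per-example gradients, each clipped to norm <= max_grad_norm.
    for x, y in zip(batch_x, batch_y):
        loss = loss_fn(model(x.unsqueeze(0)), y.unsqueeze(0))
        grads = torch.autograd.grad(loss, params)
        norm = torch.sqrt(sum(g.pow(2).sum() for g in grads))
        scale = (max_grad_norm / (norm + 1e-6)).clamp(max=1.0)
        for s, g in zip(summed, grads):
            s += g * scale

    # 2. Add noise calibrated to the clipping bound, average, and step.
    batch_size = len(batch_x)
    for p, s in zip(params, summed):
        noise = torch.randn_like(s) * noise_multiplier * max_grad_norm
        p.grad = (s + noise) / batch_size
    optimizer.step()
    optimizer.zero_grad()
    # A privacy accountant (see Ponomareva et al. 2023) then converts
    # noise_multiplier and the number of steps into an (epsilon, delta) guarantee.
```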
Using this stack, Sarus fine-tuned a Mistral 7B Instruct model on May's patient-doctor conversations with differential privacy guarantees.
The resulting model, Private Sarus May Mistral 7B (PSMM7B), was able to generate relevant responses that could be suggested to doctors to send to patients, while coming with formal guarantees that strictly limit how much personal information from the training dataset it can reveal.
Evaluation of PSMM7B
PSMM7B was evaluated for performance with and without privacy protection. Privacy was evaluated by conducting membership inference attacks on the models.
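For context, the simplest form of membership inference thresholds the model's loss on candidate samples: training examples tend to have unusually low loss. The sketch below is a generic illustration of this idea, not necessarily the exact attack protocol used in the evaluation.

```python
# Loss-threshold membership inference attack (generic illustration; not
# necessarily the exact attack protocol used in this evaluation).
# Samples with unusually low loss are guessed to be training members.
import numpy as np

def loss_threshold_attack(member_losses, non_member_losses):
    """Best attack accuracy over all thresholds; ~0.5 means no leakage signal."""
    losses = np.concatenate([member_losses, non_member_losses])
    labels = np.concatenate([np.ones(len(member_losses)),
                             np.zeros(len(non_member_losses))])
    best = 0.0
    for t in np.unique(losses):
        guesses = (losses <= t).astype(float)  # low loss => guess "member"
        best = max(best, float((guesses == labels).mean()))
    return best
```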
Because May's data is highly sensitive, the evaluations described in this document were run on a synthetic dataset suitable for the task. Using such a dataset allows us to run controlled experiments and to disclose examples of privacy leaks.
The dataset
The dataset, called “medical_extended”, is a synthetic set of generated patient questions and doctor answers (Q&A) about fake symptoms, fake diseases and fake treatments.
Samples from the dataset look like the following:
About 9,000 Q&A pairs are used for training and the remaining 1,000 for testing.
To generate the data, a list of fake (disease, treatment) pairs was generated using Anthropic’s Claude; symptoms and posologies were also generated and associated randomly, so the link between symptoms and diseases is new knowledge that was not previously present in Claude. Basic questions and answers were then generated using simple templates and reformulated with Claude, adding personal details about the patients.
The code to generate the data is open-source.
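For illustration, here is a simplified sketch of the template step; the actual open-source generator differs (notably in the Claude-based reformulation), and all disease, drug, and symptom names below are invented.

```python
# Simplified sketch of the template step; the actual open-source generator
# differs (notably the Claude-based reformulation). All names are invented.
import random

# Fake disease -> (symptoms, treatment, posology). Associations are random,
# so the symptom-disease link is new knowledge for any pre-trained model.
fake_diseases = {
    "Velmoria": (["blue spots on the elbows", "morning dizziness"],
                 "Xantrelin", "5 mg twice a day"),
    "Dropaxis": (["tingling ears", "sudden fatigue"],
                 "Corvalem", "10 mg at bedtime"),
}

def make_qa(disease: str) -> dict:
    symptoms, treatment, posology = fake_diseases[disease]
    question = f"My baby has {', '.join(symptoms)}. What could it be?"
    answer = (f"These symptoms are typical of {disease}. "
              f"The usual treatment is {treatment}, {posology}.")
    return {"question": question, "answer": answer,
            "disease": disease, "treatment": treatment}

random.seed(0)
dataset = [make_qa(d) for d in random.choices(list(fake_diseases), k=10)]
```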
The Baseline: A Simple Fine-Tuning of Mistral 7B Instruct
The first step toward using the wealth of medical Q&A data May accumulated was to simply fine-tune Mistral 7B Instruct without any particular precaution. The model was fine-tuned with QLoRA and standard gradient descent for 25 epochs.
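For readers unfamiliar with QLoRA, here is a minimal sketch of such a setup using the Hugging Face transformers and peft libraries; the checkpoint name and hyperparameters below are illustrative assumptions, not the ones used in the actual experiment.

```python
# Illustrative QLoRA setup with Hugging Face transformers + peft; the
# checkpoint name and hyperparameters are assumptions, not the experiment's.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

model_name = "mistralai/Mistral-7B-Instruct-v0.2"  # assumed checkpoint
bnb_config = BitsAndBytesConfig(load_in_4bit=True,
                                bnb_4bit_compute_dtype=torch.bfloat16)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             quantization_config=bnb_config)

# Low-rank adapters on the attention projections; only these are trained,
# while the 4-bit base weights stay frozen.
lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05,
                         target_modules=["q_proj", "v_proj"],
                         task_type="CAUSAL_LM")
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # a small fraction of the 7B parameters
```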
Because the dataset is purely synthetic, the pre-trained Mistral 7B model has no prior knowledge of the mapping between symptom description and diagnostic or treatment.
The performance on a thousand held-out Q&A pairs is good: the accuracy for disease prediction is 96.1% and the accuracy for drug prediction is 94.6%.
Here are a few examples:
Because we generated the dataset, we know the ground truth (disease, treatment) pair associated with symptoms, and we can evaluate the responses by testing if they contain the disease and treatment as substrings. With this method, we measure that a large percentage of predictions are indeed correct.
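In code, this evaluation amounts to a simple substring check. Here is a sketch, assuming the field names from the generation sketch above and a `generate` argument standing in for the fine-tuned model's generation function:

```python
# Substring-based evaluation: a response counts as correct if it mentions the
# ground-truth disease / treatment (field names follow the generation sketch).
def accuracy(examples, generate):
    disease_hits = drug_hits = 0
    for ex in examples:
        response = generate(ex["question"]).lower()
        disease_hits += ex["disease"].lower() in response
        drug_hits += ex["treatment"].lower() in response
    n = len(examples)
    return disease_hits / n, drug_hits / n
```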
To give an idea of the impact of fine-tuning on responses, one can compare the responses given by the base model (no fine-tuning), after just one epoch, and after 10 epochs of fine-tuning.
Prior to fine-tuning (response straight from Mistral 7B)
Unsurprisingly, with no fine-tuning, the model cannot tell us much about a fake disease randomly associated with fake symptoms.
After just one epoch
After one epoch, the model has picked up some of the style of the responses, but most (in fact all) of them are wrong.
After 10 epochs
At 10 epochs, the model reaches its minimum validation loss on the text completion task (see Figure 1).
Beyond 10 epochs, the validation loss starts to increase (see Figure 1), while disease prediction continues to improve in some cases, which is consistent with Gekhman et al. (2024). For instance, for less frequent diseases in the training set, the accuracy keeps increasing from 10 to 25 epochs (see Figure 2 and Figure 3).
The Privacy Problem with Simple Fine-Tuning
One of the problems with fine-tuning is that, if one is not careful with the way the model is trained, it can output verbatim fragments of the training set.
To measure this effect, we prompt the various models with questions from their training dataset and look at the proportion of characters of the training response that is covered by the longest common substring between the generated response and the training response.
The average proportion of text in common is called the privacy risk index, and the percentage of cases where this proportion exceeds 30% is called the privacy breach index.
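A sketch of how these two indices can be computed, using Python's difflib for the longest common substring (the exact implementation used in the experiments may differ):

```python
# Privacy risk index and privacy breach index from the longest common substring
# between generated and training responses (the exact implementation may differ).
from difflib import SequenceMatcher

def overlap_ratio(generated: str, training: str) -> float:
    """Fraction of the training response covered by the longest common substring."""
    matcher = SequenceMatcher(None, generated, training)
    match = matcher.find_longest_match(0, len(generated), 0, len(training))
    return match.size / max(len(training), 1)

def privacy_indices(pairs, breach_threshold=0.30):
    """pairs: (generated_response, training_response) for training-set questions."""
    ratios = [overlap_ratio(g, t) for g, t in pairs]
    risk_index = sum(ratios) / len(ratios)                      # average overlap
    breach_index = sum(r > breach_threshold for r in ratios) / len(ratios)
    return risk_index, breach_index
```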
These indices are above 12% for the model fine-tuned without DP, which illustrates the privacy problem posed by plain fine-tuning on sensitive data (see also Table 6, where we compare training-set responses with the responses generated for training-set questions).
The experiments with differential privacy
To prevent privacy problems we fine-tuned our model with differential privacy guarantees using DP-SGD.
The impact on privacy protection is clear: our indices go from 12–15% down to around 1–6%, but this seems to come at the price of a serious drop in accuracy (see Table 8).
Digging further into the experimental results, we can make two remarks that put this first conclusion in perspective.
- Our privacy metric is not very precise because of the relatively low variety of the dataset: even a model that memorized nothing is fairly likely to formulate answers similar to those in the training set. To measure privacy leakage more reliably, we should probably insert a controlled piece of private information into each training response and check whether the model outputs that private bit of the answer.
- When splitting the results by frequency of the disease in the training set, we see that accuracy can be good, even very good, for frequent diseases (93% accuracy for a disease occurring 500 times in the dataset, even with ε = 12 and δ = 1e-6).
The experiments (see Figure 2 and Figure 3) show that differential privacy degrades the accuracy of the model when too few examples of a disease are in the training set.
The fact that accuracy degrades quickly as the frequency of a disease decreases is actually a feature, not a real problem: the lower the frequency, the closer the information is to private information. Pushing this reasoning to the extreme, a disease that only one person in the dataset has is an identifier of that person, and this is exactly the kind of information we want to avoid learning.
With this in mind, we see that drug or disease prediction can be very good given enough examples of patients with a disease, even under the differential privacy protection applied here (ε ≈ 12).