In this tutorial, you will learn the key concepts behind MOSTLY AI’s synthetic data Quality Assurance (QA) framework. This will enable you to efficiently and reliably assess the quality of your generated synthetic datasets. It will also give you the skills to confidently explain the quality metrics to any interested stakeholders.

Using the code in this tutorial, you will replicate key parts of both the accuracy and privacy metrics that you will find in any MOSTLY AI QA Report. For a full-fledged exploration of the topic including a detailed mathematical explanation, see our peer-reviewed journal paper as well as the accompanying benchmarking study.

The Python code for this tutorial is publicly available and runnable in this Google Colab notebook.

QA reports for synthetic data sets

If you have run any synthetic data generation jobs with MOSTLY AI, chances are high that you’ve already encountered the QA Report. To access it, click on any completed synthesization job and select the “QA Report” tab:

Fig 1 - Click on a completed synthesization job.

Fig 2 - Select the “QA Report” tab.

At the top of the QA Report you will find some summary statistics about the dataset as well as the average metrics for accuracy and privacy of the generated dataset. Further down, you can toggle between the Model QA Report and the Data QA Report. The Model QA reports on the accuracy and privacy of the trained Generative AI model. The Data QA, on the other hand, visualizes the distributions not of the underlying model but of the outputted synthetic dataset. If you generate a synthetic dataset with all the default settings enabled, the Model and Data QA Reports should look the same. 

Exploring either of the QA reports, you will discover various performance metrics, such as univariate and bivariate distributions for each of the columns as well as more detailed privacy metrics. You can use these metrics to precisely evaluate the quality of your synthetic dataset.

So how does MOSTLY AI calculate these quality assurance metrics?

In the following sections you will replicate the accuracy and privacy metrics. The code is almost exactly the code that MOSTLY AI runs under the hood to generate the QA Reports – it has been tweaked only slightly to improve legibility and usability. Working through this code will give you a hands-on insight into how MOSTLY AI evaluates synthetic data quality.

Preprocessing the data

The first step in MOSTLY AI’s synthetic data quality evaluation methodology is to take the original dataset and split it in half to yield two subsets: a training dataset and a holdout dataset. We then use only the training samples (so only 50% of the original dataset) to train our synthesizer and generate synthetic data samples. The holdout samples are never exposed to the synthesis process but are kept aside for evaluation.

Fig 3 - The first step is to split the original dataset in two equal parts and train the synthesizer on only one of the halves.
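As a rough illustration of this preprocessing step, the split could be sketched with pandas as follows. The toy `original` DataFrame, the fixed seed and the variable names are assumptions made for this example; this is not the code MOSTLY AI runs.

```python
import numpy as np
import pandas as pd

# hypothetical toy dataset standing in for the original data
original = pd.DataFrame(
    {
        "age": np.arange(20, 30),
        "income": ["low", "high"] * 5,
    }
)

# shuffle once with a fixed seed, then cut the shuffled frame in half
shuffled = original.sample(frac=1, random_state=42)
half = len(shuffled) // 2
training = shuffled.iloc[:half]  # the only part the synthesizer would see
holdout = shuffled.iloc[half:]   # kept aside for evaluation only

print(len(training), len(holdout))
```

Because both halves come from one shuffle, no record can end up in both the training and the holdout set.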

Distance-based quality metrics for synthetic data generation

Both the accuracy and privacy metrics are measured in terms of distance. Remember that we split the original dataset into two subsets: a training and a holdout set. Since both are samples from the same dataset, they will exhibit largely the same statistics and distributions. However, as the split was made at random, we can expect a slight difference in the statistical properties of the two subsets. This difference is normal and is due to sampling variance.

The difference (or, to put it mathematically: the distance) between the training and holdout samples will serve us as a reference point: in an ideal scenario, the synthetic data we generate should be no more different from the training dataset than the holdout dataset is. Or to put it differently: the distance between the synthetic samples and the training samples should approximate the distance we would expect to occur naturally between the training and holdout samples due to sampling variance.

If the synthetic data is significantly closer to the training data than the holdout data is, this means that some information specific to the training data has leaked into the synthetic dataset. If the synthetic data is significantly farther from the training data than the holdout data is, this means that we have lost information in terms of accuracy or fidelity.

For more context on this distance-based quality evaluation approach, check out our benchmarking study which dives into more detail.

Fig 4 - A perfect synthetic data generator creates data samples that are just as different from the training data as the holdout data. If this is not the case, we are compromising on either privacy or utility.
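To get a feel for how large the "natural" distance between a training and a holdout split is, here is a minimal sketch on made-up data. It uses the total variation distance from the accuracy section as one possible distance; the `tvd` helper, the simulated column and the seed are assumptions made for this example.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
# hypothetical categorical column with three levels
values = pd.Series(rng.choice(["a", "b", "c"], size=10_000, p=[0.5, 0.3, 0.2]))

# random 50/50 split into training and holdout
shuffled = values.sample(frac=1, random_state=0)
trn, hol = shuffled.iloc[:5_000], shuffled.iloc[5_000:]


def tvd(s1, s2):
    """Total variation distance between the relative frequencies of two series."""
    p1 = s1.value_counts(normalize=True)
    p2 = s2.value_counts(normalize=True)
    return p1.subtract(p2, fill_value=0.0).abs().sum() / 2


baseline = tvd(trn, hol)
print(f"training vs. holdout TVD: {baseline:.4f}")
```

The resulting value is small and is caused by sampling variance alone. A synthetic dataset whose distance to the training data is in the same range would be behaving like a fresh holdout sample.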

Let’s jump into replicating the metrics for both accuracy and privacy 👇

Synthetic data accuracy 

The accuracy of MOSTLY AI’s synthetic datasets is measured as the total variation distance (TVD) between the empirical marginal distributions. It is calculated by treating all the variables in the dataset as categoricals (by binning any numerical features) and then taking half the sum of the absolute deviations between the empirical marginal distributions of the original and the synthetic data. The accuracy is then reported as 100% minus the TVD.
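As a small worked example of this distance, consider one binned column with three categories. The frequencies below are made up for illustration, and the factor of one half mirrors the `tvd` line in the `calculate_accuracies` helper later in this tutorial.

```python
import numpy as np

# hypothetical relative frequencies of one binned column
p_original = np.array([0.5, 0.3, 0.2])
p_synthetic = np.array([0.4, 0.4, 0.2])

# TVD is half the sum of absolute differences between the two distributions
tvd = np.abs(p_original - p_synthetic).sum() / 2
accuracy = 1 - tvd

print(f"TVD: {tvd:.2f}, accuracy: {accuracy:.0%}")
```

Here 10% of the probability mass sits in a different category, so the TVD is 0.10 and the accuracy for this column is 90%.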

The code below performs the calculation for all univariate and bivariate distributions and then averages across to determine the simple summary statistics you see in the QA Report.

First things first: let’s access the data. You can fetch both the original and the synthetic datasets directly from the GitHub repo:

import numpy as np
import pandas as pd

repo = (
    "https://github.com/mostly-ai/mostly-tutorials/raw/dev/quality-assurance"
)
tgt = pd.read_parquet(f"{repo}/census-training.parquet")
print(
    f"fetched original data with {tgt.shape[0]:,} records and {tgt.shape[1]} attributes"
)
syn = pd.read_parquet(f"{repo}/census-synthetic.parquet")
print(
    f"fetched synthetic data with {syn.shape[0]:,} records and {syn.shape[1]} attributes"
)

fetched original data with 39,074 records and 12 attributes
fetched synthetic data with 39,074 records and 12 attributes

We are working with a version of the UCI Adult Income dataset. This dataset has just over 39K records and 12 columns. Go ahead and sample 5 random records to get a sense of what the data looks like:

tgt.sample(n=5)

Let’s define a helper function to bin the data in order to treat any numerical features as categoricals:

def bin_data(dt1, dt2, bins=10):
    dt1 = dt1.copy()
    dt2 = dt2.copy()
    # quantile binning of numerics
    num_cols = dt1.select_dtypes(include="number").columns
    cat_cols = dt1.select_dtypes(
        include=["object", "category", "string", "bool"]
    ).columns
    for col in num_cols:
        # determine breaks based on `dt1`
        breaks = dt1[col].quantile(np.linspace(0, 1, bins + 1)).unique()
        dt1[col] = pd.cut(dt1[col], bins=breaks, include_lowest=True)
        dt2_vals = pd.to_numeric(dt2[col], "coerce")
        dt2_bins = pd.cut(dt2_vals, bins=breaks, include_lowest=True)
        dt2_bins[dt2_vals < min(breaks)] = "_other_"
        dt2_bins[dt2_vals > max(breaks)] = "_other_"
        dt2[col] = dt2_bins
    # top-C binning of categoricals
    for col in cat_cols:
        dt1[col] = dt1[col].astype("str")
        dt2[col] = dt2[col].astype("str")
        # determine top values based on `dt1`
        top_vals = dt1[col].value_counts().head(bins).index.tolist()
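        # keep the top values and lump everything else into "_other_";
        # assigning the result back avoids a chained in-place replace,
        # which newer pandas versions may not apply to the DataFrame
        dt1[col] = dt1[col].where(dt1[col].isin(top_vals), "_other_")
        dt2[col] = dt2[col].where(dt2[col].isin(top_vals), "_other_")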
    return dt1, dt2
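To make the numeric part of the binning more tangible, here is a minimal sketch on a single made-up column. It uses four bins instead of ten to keep the numbers readable; the toy values and variable names are assumptions made for this example.

```python
import numpy as np
import pandas as pd

# hypothetical numeric column from the original and the synthetic data
original = pd.Series([1, 2, 3, 4, 5, 6, 7, 8])
synthetic = pd.Series([2, 4, 6, 9])

# quartile breaks are derived from the original data only
breaks = original.quantile(np.linspace(0, 1, 5)).unique()

original_bins = pd.cut(original, bins=breaks, include_lowest=True)
synthetic_bins = pd.cut(synthetic, bins=breaks, include_lowest=True)

print(breaks)
print(synthetic_bins.isna().sum())
```

The synthetic value 9 lies above the largest break, so `pd.cut` leaves it without a bin. This is the situation the `_other_` handling in `bin_data` is meant to cover.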

And a second helper function to calculate the univariate and bivariate accuracies:

def calculate_accuracies(dt1_bin, dt2_bin, k=1):
    # build grid of all cross-combinations
    cols = dt1_bin.columns
    interactions = pd.DataFrame(
        np.array(np.meshgrid(cols, cols)).reshape(2, len(cols) ** 2).T
    )
    interactions.columns = ["col1", "col2"]
    if k == 1:
        interactions = interactions.loc[
            (interactions["col1"] == interactions["col2"])
        ]
    elif k == 2:
        interactions = interactions.loc[
            (interactions["col1"] < interactions["col2"])
        ]
    else:
        raise ValueError("only k=1 and k=2 are supported")

    results = []
    for idx in range(interactions.shape[0]):
        row = interactions.iloc[idx]
        val1 = (
            dt1_bin[row.col1].astype(str) + "|" + dt1_bin[row.col2].astype(str)
        )
        val2 = (
            dt2_bin[row.col1].astype(str) + "|" + dt2_bin[row.col2].astype(str)
        )
        # calculate empirical marginal distributions (=relative frequencies)
        freq1 = val1.value_counts(normalize=True, dropna=False).to_frame(
            name="p1"
        )
        freq2 = val2.value_counts(normalize=True, dropna=False).to_frame(
            name="p2"
        )
        freq = freq1.join(freq2, how="outer").fillna(0.0)
        # calculate Total Variation Distance between relative frequencies
        tvd = np.sum(np.abs(freq["p1"] - freq["p2"])) / 2
        # calculate Accuracy as (100% - TVD)
        acc = 1 - tvd
        out = pd.DataFrame(
            {
                "Column": [row.col1],
                "Column 2": [row.col2],
                "TVD": [tvd],
                "Accuracy": [acc],
            }
        )
        results.append(out)

    return pd.concat(results)
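The `col1 < col2` filter above keeps each unordered pair of columns exactly once. An equivalent way to picture this, shown here as a sketch with made-up column names, is `itertools.combinations`:

```python
from itertools import combinations

# hypothetical column names
cols = ["age", "education", "income", "sex"]

# every unordered pair appears exactly once, like the `col1 < col2` filter
pairs = list(combinations(sorted(cols), 2))

print(len(pairs))
print(pairs[:3])
```

Four columns yield 6 pairs. By the same formula, n * (n - 1) / 2, the 12 columns of the census dataset yield 66 bivariate distributions.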

Then go ahead and bin the data. We restrict ourselves to 100K records for efficiency.

# restrict to max 100k records
tgt = tgt.sample(frac=1).head(n=100_000)
syn = syn.sample(frac=1).head(n=100_000)
# bin data
tgt_bin, syn_bin = bin_data(tgt, syn, bins=10)

Now you can go ahead and calculate the univariate accuracies for all the columns in the dataset:

# calculate univariate accuracies
acc_uni = calculate_accuracies(tgt_bin, syn_bin, k=1)[['Column', 'Accuracy']]

Go ahead and inspect the results for the first 5 columns:

acc_uni.head()

Now let’s calculate the bivariate accuracies as well. These measure how well the relationship between each pair of columns is maintained.

# calculate bivariate accuracies
acc_biv = calculate_accuracies(tgt_bin, syn_bin, k=2)[
    ["Column", "Column 2", "Accuracy"]
]
acc_biv = pd.concat(
    [
        acc_biv,
        acc_biv.rename(columns={"Column": "Column 2", "Column 2": "Column"}),
    ]
)
acc_biv.head()

The bivariate accuracy that is reported for each column in the MOSTLY AI QA Report is an average over all of the bivariate accuracies for that column with respect to all the other columns in the dataset. Let’s calculate that value for each column and then create an overview table with the univariate and average bivariate accuracies for all columns:

# calculate the average bivariate accuracy
acc_biv_avg = (
    acc_biv.groupby("Column")["Accuracy"]
    .mean()
    .to_frame("Bivariate Accuracy")
    .reset_index()
)
# merge to univariate and avg. bivariate accuracy to single overview table
acc = pd.merge(
    acc_uni.rename(columns={"Accuracy": "Univariate Accuracy"}),
    acc_biv_avg,
    on="Column",
).sort_values("Univariate Accuracy", ascending=False)
# report accuracy as percentage
acc["Univariate Accuracy"] = acc["Univariate Accuracy"].apply(
    lambda x: f"{x:.1%}"
)
acc["Bivariate Accuracy"] = acc["Bivariate Accuracy"].apply(
    lambda x: f"{x:.1%}"
)
acc
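The mirroring and averaging steps can be easier to follow on a tiny example. The sketch below uses three made-up pair accuracies; the column names and values are assumptions made for this example.

```python
import pandas as pd

# hypothetical bivariate accuracies, one row per unordered column pair
pairs = pd.DataFrame(
    {
        "Column": ["age", "age", "income"],
        "Column 2": ["income", "sex", "sex"],
        "Accuracy": [0.98, 0.96, 0.94],
    }
)

# mirror the pairs so that each column shows up in the "Column" position
mirrored = pd.concat(
    [pairs, pairs.rename(columns={"Column": "Column 2", "Column 2": "Column"})]
)

# average over all pairs that a column takes part in
avg = mirrored.groupby("Column")["Accuracy"].mean()
print(avg)
```

Without the mirroring, `sex` would never appear in the "Column" position and would be missing from the per-column averages.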

Finally, let’s calculate the summary statistic values that you normally see at the top of any MOSTLY AI QA Report: the overall accuracy as well as the average univariate and bivariate accuracies. We take the mean of the univariate and bivariate accuracies for all the columns and then take the mean of the result to arrive at the overall accuracy score:

print(f"Avg. Univariate Accuracy: {acc_uni['Accuracy'].mean():.1%}")
print(f"Avg. Bivariate Accuracy:  {acc_biv['Accuracy'].mean():.1%}")
print("-------------------------------")
acc_avg = (acc_uni["Accuracy"].mean() + acc_biv["Accuracy"].mean()) / 2
print(f"Avg. Overall Accuracy:    {acc_avg:.1%}")

Avg. Univariate Accuracy: 98.9%
Avg. Bivariate Accuracy:  97.7%
------------------------------
Avg. Overall Accuracy:    98.3%

If you’re curious how this compares to the values in the MOSTLY AI QA Report, go ahead and download the tgt dataset and synthesize it using the default settings. The overall accuracy reported will be close to 98%.

Next, let’s see how MOSTLY AI generates the visualization segments of the accuracy report. The code below defines two helper functions: one for the univariate and one for the bivariate plots. Getting the plots right for all possible edge cases is actually rather complicated, so while the code block below is lengthy, this is in fact the trimmed-down version of what MOSTLY AI uses under the hood. You do not need to worry about the exact details of the implementation here; just getting an overall sense of how it works is enough:

import plotly.graph_objects as go


def plot_univariate(tgt_bin, syn_bin, col, accuracy):
    freq1 = (
        tgt_bin[col].value_counts(normalize=True, dropna=False).to_frame("tgt")
    )
    freq2 = (
        syn_bin[col].value_counts(normalize=True, dropna=False).to_frame("syn")
    )
    freq = freq1.join(freq2, how="outer").fillna(0.0).reset_index()
    freq = freq.sort_values(col)
    freq[col] = freq[col].astype(str)

    layout = go.Layout(
        title=dict(
            text=f"<b>{col}</b> <sup>{accuracy:.1%}</sup>", x=0.5, y=0.98
        ),
        autosize=True,
        height=300,
        width=800,
        margin=dict(l=10, r=10, b=10, t=40, pad=5),
        plot_bgcolor="#eeeeee",
        hovermode="x unified",
        yaxis=dict(
            zerolinecolor="white",
            rangemode="tozero",
            tickformat=".0%",
        ),
    )
    fig = go.Figure(layout=layout)
    trn_line = go.Scatter(
        mode="lines",
        x=freq[col],
        y=freq["tgt"],
        name="target",
        line_color="#666666",
        yhoverformat=".2%",
    )
    syn_line = go.Scatter(
        mode="lines",
        x=freq[col],
        y=freq["syn"],
        name="synthetic",
        line_color="#24db96",
        yhoverformat=".2%",
        fill="tonexty",
        fillcolor="#ffeded",
    )
    fig.add_trace(trn_line)
    fig.add_trace(syn_line)
    fig.show(config=dict(displayModeBar=False))


def plot_bivariate(tgt_bin, syn_bin, col1, col2, accuracy):
    x = (
        pd.concat([tgt_bin[col1], syn_bin[col1]])
        .drop_duplicates()
        .to_frame(col1)
    )
    y = (
        pd.concat([tgt_bin[col2], syn_bin[col2]])
        .drop_duplicates()
        .to_frame(col2)
    )
    df = pd.merge(x, y, how="cross")
    df = pd.merge(
        df,
        pd.concat([tgt_bin[col1], tgt_bin[col2]], axis=1)
        .value_counts()
        .to_frame("target")
        .reset_index(),
        how="left",
    )
    df = pd.merge(
        df,
        pd.concat([syn_bin[col1], syn_bin[col2]], axis=1)
        .value_counts()
        .to_frame("synthetic")
        .reset_index(),
        how="left",
    )
    df = df.sort_values([col1, col2], ascending=[True, True]).reset_index(
        drop=True
    )
    df["target"] = df["target"].fillna(0.0)
    df["synthetic"] = df["synthetic"].fillna(0.0)
    # normalize values row-wise (used for visualization)
    df["target_by_row"] = df["target"] / df.groupby(col1)["target"].transform(
        "sum"
    )
    df["synthetic_by_row"] = df["synthetic"] / df.groupby(col1)[
        "synthetic"
    ].transform("sum")
    # normalize values across table (used for accuracy)
    df["target_by_all"] = df["target"] / df["target"].sum()
    df["synthetic_by_all"] = df["synthetic"] / df["synthetic"].sum()
    df["y"] = df[col1].astype("str")
    df["x"] = df[col2].astype("str")

    layout = go.Layout(
        title=dict(
            text=f"<b>{col1} ~ {col2}</b> <sup>{accuracy:.1%}</sup>",
            x=0.5,
            y=0.98,
        ),
        autosize=True,
        height=300,
        width=800,
        margin=dict(l=10, r=10, b=10, t=40, pad=5),
        plot_bgcolor="#eeeeee",
        showlegend=True,
        # prevent Plotly from trying to convert strings to dates
        xaxis=dict(type="category"),
        xaxis2=dict(type="category"),
        yaxis=dict(type="category"),
        yaxis2=dict(type="category"),
    )
    fig = go.Figure(layout=layout).set_subplots(
        rows=1,
        cols=2,
        horizontal_spacing=0.05,
        shared_yaxes=True,
        subplot_titles=("target", "synthetic"),
    )
    fig.update_annotations(font_size=12)
    # plot content
    hovertemplate = (
        col1[:10] + ": `%{y}`<br />" + col2[:10] + ": `%{x}`<br /><br />"
    )
    hovertemplate += "share target vs. synthetic<br />"
    hovertemplate += "row-wise: %{customdata[0]} vs. %{customdata[1]}<br />"
    hovertemplate += "absolute: %{customdata[2]} vs. %{customdata[3]}<br />"
    customdata = df[
        [
            "target_by_row",
            "synthetic_by_row",
            "target_by_all",
            "synthetic_by_all",
        ]
    ].apply(lambda x: x.map("{:.2%}".format))
    heat1 = go.Heatmap(
        x=df["x"],
        y=df["y"],
        z=df["target_by_row"],
        name="target",
        zmin=0,
        zmax=1,
        autocolorscale=False,
        colorscale=["white", "#A7A7A7", "#7B7B7B", "#666666"],
        showscale=False,
        customdata=customdata,
        hovertemplate=hovertemplate,
    )
    heat2 = go.Heatmap(
        x=df["x"],
        y=df["y"],
        z=df["synthetic_by_row"],
        name="synthetic",
        zmin=0,
        zmax=1,
        autocolorscale=False,
        colorscale=["white", "#81EAC3", "#43E0A5", "#24DB96"],
        showscale=False,
        customdata=customdata,
        hovertemplate=hovertemplate,
    )
    fig.add_trace(heat1, row=1, col=1)
    fig.add_trace(heat2, row=1, col=2)
    fig.show(config=dict(displayModeBar=False))

Now you can create the plots for the univariate distributions:

for idx, row in acc_uni.sample(n=5, random_state=0).iterrows():
    plot_univariate(tgt_bin, syn_bin, row["Column"], row["Accuracy"])
    print("")

Fig 5 - Sample of 2 univariate distribution plots.

As well as the bivariate distribution plots:

for idx, row in acc_biv.sample(n=5, random_state=0).iterrows():
    plot_bivariate(
        tgt_bin, syn_bin, row["Column"], row["Column 2"], row["Accuracy"]
    )
    print("")

Fig 6 - Sample of 2 bivariate distribution plots.

Now that you have replicated the accuracy component of the QA Report in sufficient detail, let’s move on to the privacy section.

Synthetic data privacy

Just like accuracy, the privacy metric is also calculated as a distance-based value. To gauge the privacy risk of the generated synthetic data, we calculate the distances between the synthetic samples and their "nearest neighbor" (i.e., their most similar record) from the original dataset. This nearest neighbor could be either in the training split or in the holdout split. We then tally the share of synthetic samples that are closer to the training set versus the share that are closer to the holdout set. Ideally, we will see an even split, which would mean that the synthetic samples are not systematically any closer to the original dataset than the original samples are to each other.

Fig 7 - A perfect synthetic data generator creates synthetic records that are just as different from the training data as from the holdout data.
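The tally described above could be sketched as follows. This is an illustrative example on random, already numeric toy data, not the code behind the QA Report; the `nearest_distance` helper and all variable names are assumptions made for this example.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
# hypothetical, already numerically encoded records from the same distribution
trn = rng.normal(size=(500, 3))
hol = rng.normal(size=(500, 3))
syn = rng.normal(size=(500, 3))


def nearest_distance(reference, queries):
    """Distance from each query record to its closest record in `reference`."""
    index = NearestNeighbors(n_neighbors=1).fit(reference)
    distances, _ = index.kneighbors(queries)
    return distances[:, 0]


d_trn = nearest_distance(trn, syn)
d_hol = nearest_distance(hol, syn)

# share of synthetic records whose nearest neighbor is a training record
share_closer_to_trn = np.mean(d_trn < d_hol)
print(f"closer to training: {share_closer_to_trn:.1%}")
```

Since all three toy sets are drawn independently from the same distribution, the share lands near 50%. A share well above 50% would suggest that the synthetic records sit systematically closer to the training data.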

The code block below uses the scikit-learn library to perform a nearest-neighbor search across the synthetic and original datasets. We then use the results from this search to calculate two different distance metrics: the Distance to the Closest Record (DCR) and the Nearest Neighbor Distance Ratio (NNDR), both at the 5-th percentile.

from sklearn.compose import make_column_transformer
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import OneHotEncoder
from sklearn.impute import SimpleImputer


no_of_records = min(tgt.shape[0] // 2, syn.shape[0], 10_000)
tgt = tgt.sample(n=2 * no_of_records)
trn = tgt.head(no_of_records)
hol = tgt.tail(no_of_records)
syn = syn.sample(n=no_of_records)


string_cols = trn.select_dtypes(exclude=np.number).columns
numeric_cols = trn.select_dtypes(include=np.number).columns
transformer = make_column_transformer(
    (SimpleImputer(missing_values=np.nan, strategy="mean"), numeric_cols),
    (OneHotEncoder(), string_cols),
    remainder="passthrough",
)
transformer.fit(pd.concat([trn, hol, syn], axis=0))
trn_hot = transformer.transform(trn)
hol_hot = transformer.transform(hol)
syn_hot = transformer.transform(syn)


# calculate distances to nearest neighbors
index = NearestNeighbors(
    n_neighbors=2, algorithm="brute", metric="l2", n_jobs=-1
)
index.fit(trn_hot)
# k-nearest-neighbor search for both holdout and synthetic data, k=2 to calculate DCR + NNDR
dcrs_hol, _ = index.kneighbors(hol_hot)
dcrs_syn, _ = index.kneighbors(syn_hot)
dcrs_hol = np.square(dcrs_hol)
dcrs_syn = np.square(dcrs_syn)
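To see what the two metrics capture for a single record, here is a minimal sketch with three made-up training records and one query record. As in the code above, the distances are squared; the toy coordinates are assumptions made for this example.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

# hypothetical training records and a single query record
trn = np.array([[0.0, 0.0], [3.0, 0.0], [0.0, 4.0]])
query = np.array([[1.0, 0.0]])

index = NearestNeighbors(n_neighbors=2, algorithm="brute", metric="l2").fit(trn)
distances, _ = index.kneighbors(query)

# squared distances to the closest and second-closest training record
closest, second = np.square(distances[0])
dcr = closest            # Distance to the Closest Record
nndr = closest / second  # Nearest Neighbor Distance Ratio

print(dcr, nndr)
```

The query is 1 unit away from its closest record and 2 units away from the second closest, which gives a squared DCR of 1 and an NNDR of 0.25. An NNDR close to 0 means the record is much closer to one specific training record than to any other.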

Now calculate the DCR for both datasets:

dcr_bound = np.maximum(np.quantile(dcrs_hol[:, 0], 0.95), 1e-8)
ndcr_hol = dcrs_hol[:, 0] / dcr_bound
ndcr_syn = dcrs_syn[:, 0] / dcr_bound
print(
    f"Normalized DCR 5-th percentile original  {np.percentile(ndcr_hol, 5):.3f}"
)
print(
    f"Normalized DCR 5-th percentile synthetic {np.percentile(ndcr_syn, 5):.3f}"
)

Normalized DCR 5-th percentile original  0.001
Normalized DCR 5-th percentile synthetic 0.009

As well as the NNDR:

print(
    f"NNDR 5-th percentile original  {np.percentile(dcrs_hol[:,0]/dcrs_hol[:,1], 5):.3f}"
)
print(
    f"NNDR 5-th percentile synthetic {np.percentile(dcrs_syn[:,0]/dcrs_syn[:,1], 5):.3f}"
)

NNDR 5-th percentile original  0.019
NNDR 5-th percentile synthetic 0.058

For both privacy metrics, the value for the synthetic dataset should be similar to, but not smaller than, the value for the holdout dataset (labeled “original” in the output above). This gives us confidence that the synthetic data generator has not memorized privacy-revealing information from the training data.

Quality assurance for synthetic data with MOSTLY AI

In this tutorial, you have learned the key concepts behind MOSTLY AI’s Quality Assurance framework. You have gained insight into the preprocessing steps that are required as well as a close look into exactly how the accuracy and privacy metrics are calculated. With these newly acquired skills, you can now confidently and efficiently interpret any MOSTLY AI QA Report and explain it thoroughly to any interested stakeholders.

For a more in-depth exploration of these concepts and the mathematical principles behind them, check out the benchmarking study or the peer-reviewed academic research paper.

You can also check out the other Synthetic Data Tutorials: