Model Inversion Attack

Last modified: 2023-08-24

A model inversion attack trains a model that reconstructs a target model's inputs from its outputs, without the attacker knowing the target's architecture (a so-called black-box model).

Reference: OpenMined Tutorial

1. Import Modules

import numpy as np
from collections import namedtuple
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.datasets import EMNIST, MNIST
from tqdm.notebook import tqdm, trange

import matplotlib.pyplot as plt

2. Set Hyperparameters of Each Model

Next, we prepare the hyperparameters for each model. These values are used for training, for splitting the dataset, and so on.

hyperparams = namedtuple("hyperparams", "batch_size,epochs,learning_rate,n_data")

# Hyperparameters for victim model
victim_hyperparams = hyperparams(
    batch_size=256,
    epochs=10,
    learning_rate=1e-4,
    n_data=20_000, # the full dataset is not required
)

# Hyperparameters for the evil model used to attack
evil_hyperparams = hyperparams(
    batch_size=32,
    epochs=10,
    learning_rate=1e-4,
    n_data=500,
)
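
As a rough sanity check on these values, the sketch below computes how many batches each model sees per epoch. It assumes the DataLoader default of keeping the final partial batch (drop_last=False); the helper name batches_per_epoch is hypothetical and not part of the original tutorial.

```python
import math
from collections import namedtuple

hyperparams = namedtuple("hyperparams", "batch_size,epochs,learning_rate,n_data")

victim_hyperparams = hyperparams(batch_size=256, epochs=10, learning_rate=1e-4, n_data=20_000)
evil_hyperparams = hyperparams(batch_size=32, epochs=10, learning_rate=1e-4, n_data=500)


def batches_per_epoch(hp):
    # Hypothetical helper: with drop_last=False the last, smaller batch is kept,
    # so the count is rounded up.
    return math.ceil(hp.n_data / hp.batch_size)


print(batches_per_epoch(victim_hyperparams))  # 79
print(batches_per_epoch(evil_hyperparams))    # 16
```

The evil model trains on far fewer samples than the victim, which is the point of the setup: the attacker is assumed to have only a small amount of data.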

3. Load/Preprocess Dataset and Create DataLoader

We use the MNIST dataset for this example.

preprocess = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)),]
)

# Load datasets
train_data = MNIST("mnist", train=True, download=True, transform=preprocess)
test_data = MNIST("mnist", train=False, download=True, transform=preprocess)

# Keep only the required amount of data
train_data.data = train_data.data[:victim_hyperparams.n_data]
train_data.targets = train_data.targets[:victim_hyperparams.n_data]

# Create data loaders
train_loader = DataLoader(train_data, batch_size=victim_hyperparams.batch_size)
test_loader = DataLoader(test_data, batch_size=1_000)

4. Prepare Victim Model

Since this article is for educational purposes, we first need to create the target model that will be inverted. In practice, the attacker does not know the architecture of the target model.
Here we create a neural network named VictimNet as an example.
Its layers are split into two stages. We will intercept the output of stage1 later.

class VictimNet(nn.Module):
  def __init__(self, first_network, second_network) -> None:
    super().__init__()

    self.stage1 = first_network
    self.stage2 = second_network

  def mobile_stage(self, x):
    return self.stage1(x)

  def forward(self, x):
    out = self.mobile_stage(x)
    out = out.view(out.size(0), -1)
    return self.stage2(out)

After that, initialize the model.

first_network = nn.Sequential(
    nn.Conv2d(1, 32, kernel_size=5, padding=0, stride=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
    nn.Conv2d(32, 32, kernel_size=5, padding=0, stride=1),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2),
)

second_network = nn.Sequential(
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
    nn.Softmax(dim=-1),
)

victim_model = VictimNet(first_network, second_network)
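
To see why the first linear layer takes 512 inputs, the sketch below traces the spatial size of a 28x28 MNIST image through stage1 using the standard output-size formulas for convolution and max pooling (dilation 1). The helper names conv_out and pool_out are hypothetical, introduced only for this illustration.

```python
def conv_out(size, kernel, padding=0, stride=1):
    # Output size of a convolution along one spatial dimension (dilation = 1)
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel):
    # MaxPool2d uses stride = kernel_size when no stride is given
    return (size - kernel) // kernel + 1


size = 28                       # MNIST images are 28x28
size = conv_out(size, kernel=5)  # first Conv2d
size = pool_out(size, kernel=2)  # first MaxPool2d
size = conv_out(size, kernel=5)  # second Conv2d
size = pool_out(size, kernel=2)  # second MaxPool2d

channels = 32
flattened = channels * size * size

print(size)       # 4
print(flattened)  # 512
```

So stage1 emits a 32x4x4 feature map per image, which forward() flattens to 512 values before passing it to stage2.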

To train the victim model, execute the following.

optim = torch.optim.Adam(victim_model.parameters(), lr=victim_hyperparams.learning_rate)
loss_criterion = nn.CrossEntropyLoss()

for epoch in trange(victim_hyperparams.epochs):
  train_correct = 0
  train_loss = 0.

  for data, targets in train_loader:
    optim.zero_grad()

    output = victim_model(data)

    # Calculate loss and backpropagate
    loss = loss_criterion(output, targets)
    loss.backward()
    optim.step()

    # Record the statistics
    _, predicted = output.max(1)
    train_correct += predicted.eq(targets).sum().item()
    train_loss += loss.item()

# Each loss value is already a batch mean, so average over the number of batches
train_loss /= len(train_loader)

# Check test accuracy
test_correct = 0
test_loss = 0.

for data, targets in test_loader:
  with torch.no_grad():
    output = victim_model(data)

  loss = loss_criterion(output, targets)

  _, predicted = output.max(1)
  test_correct += predicted.eq(targets).sum().item()
  test_loss += loss.item()

# Each loss value is already a batch mean, so average over the number of batches
test_loss /= len(test_loader)

print(
    f"Training loss: {train_loss:.3f}\n"
    f"Test loss: {test_loss:.3f}"
)

print(
    f"Training accuracy: {100 * train_correct / victim_hyperparams.n_data:.3f}\n"
    f"Test accuracy: {100 * test_correct / len(test_data):.3f}"
)
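
One detail worth noting about the training code: nn.CrossEntropyLoss expects raw logits and applies log-softmax internally, while stage2 above already ends with nn.Softmax. The softmax is therefore effectively applied twice. The pure-Python sketch below (no PyTorch, hypothetical helper names) illustrates the effect on a single sample; it is only an illustration and not part of the original tutorial.

```python
import math


def softmax(values):
    # Subtract the maximum for numerical stability
    top = max(values)
    exps = [math.exp(v - top) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    # Per-sample cross entropy on logits: -log(softmax(logits)[target])
    return -math.log(softmax(logits)[target])


logits = [5.0, 0.0, 0.0]

# Loss when the model outputs logits directly
loss_on_logits = cross_entropy(logits, 0)

# Loss when the model outputs probabilities that are then treated as logits
loss_on_probs = cross_entropy(softmax(logits), 0)

# Lowest loss reachable with 10 classes when the "logits" are capped at 1.0
floor_10_classes = cross_entropy([1.0] + [0.0] * 9, 0)

print(loss_on_logits < loss_on_probs)  # True
```

With probabilities as inputs, the loss for 10 classes cannot drop below roughly 1.46, so the printed loss values stay high even when accuracy is good. Training still works, but removing the final Softmax (and applying it only at inference time) is the more common arrangement.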

5. Create Evil Model

Next, create the inverse model that attacks the target model. We call it EvilNet here.

class EvilNet(nn.Module):
  def __init__(self):
    super().__init__()

    self.layers = nn.Sequential(
        nn.ConvTranspose2d(
            in_channels=32,
            out_channels=32,
            kernel_size=7,
            padding=1,
            stride=2,
            output_padding=1,
        ),
        nn.ReLU(),
        nn.ConvTranspose2d(
            in_channels=32,
            out_channels=32,
            kernel_size=5,
            padding=1,
            stride=2,
            output_padding=1,
        ),
        nn.ReLU(),
        nn.ConvTranspose2d(
            in_channels=32, out_channels=1, kernel_size=5, padding=1, stride=1,
        ),
    )

  def forward(self, x):
    return self.layers(x)

After that, initialize the model.

evil_model = EvilNet()
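
The layer parameters of EvilNet are chosen so that a 32x4x4 feature map from stage1 is upsampled back to a 28x28 image. The sketch below checks this with the standard output-size formula for a transposed convolution (dilation 1); the helper name conv_transpose_out is hypothetical.

```python
def conv_transpose_out(size, kernel, padding=0, stride=1, output_padding=0):
    # Output size of ConvTranspose2d along one spatial dimension (dilation = 1)
    return (size - 1) * stride - 2 * padding + kernel + output_padding


size = 4  # spatial size of the stage1 output for a 28x28 input
sizes = [size]

# The three ConvTranspose2d layers of EvilNet, in order
layers = [
    dict(kernel=7, padding=1, stride=2, output_padding=1),
    dict(kernel=5, padding=1, stride=2, output_padding=1),
    dict(kernel=5, padding=1, stride=1),
]

for params in layers:
    size = conv_transpose_out(size, **params)
    sizes.append(size)

print(sizes)  # [4, 12, 26, 28]
```

The final layer also reduces the channel count from 32 to 1, so the output has the same shape as a normalized MNIST image.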

In addition, we need to prepare dataset and data loader for this evil model.

evil_dataset = EMNIST("emnist", "letters", download=True, train=False, transform=preprocess)

# Use the first n_data images in the test set to train the evil model
evil_dataset.data = evil_dataset.data[:evil_hyperparams.n_data]
evil_dataset.targets = evil_dataset.targets[:evil_hyperparams.n_data]

# Dataloader
evil_loader = DataLoader(evil_dataset, batch_size=evil_hyperparams.batch_size)

To train, execute the following script.

# Optimizer
evil_optim = torch.optim.Adam(evil_model.parameters(), lr=evil_hyperparams.learning_rate)

# Train by each epoch
for epoch in trange(evil_hyperparams.epochs):
  for data, targets in evil_loader:
    # Tensor.float() returns a new tensor, so the result must be assigned
    data = data.float()

    # Reset the gradients accumulated from the previous batch
    evil_optim.zero_grad()

    # Intercept the output of the mobile device's model.
    # This is the input of the evil model.
    with torch.no_grad():
      evil_input = victim_model.mobile_stage(data)

    output = evil_model(evil_input)

    # Calculate the mean squared loss between the predicted output and the original input data
    loss = ((output - data)**2).mean()
    loss.backward()
    evil_optim.step()
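
The idea behind this training loop can be shown with a toy problem. In the sketch below, a hypothetical one-parameter "victim stage" doubles its input, and a one-parameter "evil model" learns to undo it by minimizing the same mean squared reconstruction loss. All names and values here are illustrative assumptions, and plain gradient descent stands in for Adam.

```python
def victim_stage(x):
    # Hypothetical stand-in for victim_model.mobile_stage: an unknown transform
    return 2.0 * x


inputs = [0.5, 1.0, 1.5, 2.0]  # data the attacker can push through the victim
w = 0.0                        # single weight of the toy evil model
lr = 0.05

for step in range(50):
    grad = 0.0
    for x in inputs:
        z = victim_stage(x)            # intercepted intermediate output
        reconstruction = w * z         # evil model's guess at the input
        # Derivative of (reconstruction - x)**2 with respect to w
        grad += 2.0 * (reconstruction - x) * z
    grad /= len(inputs)                # mean over the batch, as in the MSE loss
    w -= lr * grad                     # gradient descent step

print(round(w, 6))  # 0.5
```

The learned weight converges to 0.5, the inverse of the victim's factor of 2, although the attacker never inspected victim_stage itself. EvilNet does the same thing with transposed convolutions and image data.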

6. Attack

Now that everything is in place, we can invert the target model and generate images that closely resemble the inputs given to it.
First, we create a function to plot the generated images.

def plot_images(tensors):
  fig = plt.figure(figsize=(10, 5))

  n_tensors = len(tensors)
  n_cols = min(n_tensors, 4)
  n_rows = int((n_tensors - 1) / 4) + 1

  # De-normalize on MNIST tensor
  mu = torch.tensor([0.1307], dtype=torch.float32)
  sigma = torch.tensor([0.3081], dtype=torch.float32)
  Unnormalize = transforms.Normalize((-mu / sigma).tolist(), (1.0 / sigma).tolist())

  for row in range(n_rows):
    for col in range(n_cols):
      idx = n_cols * row + col

      if idx > n_tensors - 1:
        break

      ax = fig.add_subplot(n_rows, n_cols, idx + 1)
      tensor = Unnormalize(tensors[idx])

      # Clip image values
      tensor[tensor < 0] = 0
      tensor[tensor > 1] = 1

      tensor = tensor.squeeze(0) # remove channel dim

      ax.imshow(transforms.ToPILImage()(tensor), interpolation="bicubic")

  plt.tight_layout()
  plt.show()
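
The Unnormalize transform above reuses transforms.Normalize with mean -mu / sigma and standard deviation 1 / sigma. Since Normalize computes (x - mean) / std, this choice maps a normalized value back to x * sigma + mu, which is the original pixel. The sketch below checks that identity on scalar values in plain Python; the function normalize is a hypothetical scalar stand-in for the transform.

```python
MEAN, STD = 0.1307, 0.3081  # MNIST statistics used in preprocess


def normalize(x, mean, std):
    # Scalar version of what transforms.Normalize does per channel
    return (x - mean) / std


# Parameters passed to the Unnormalize transform
inv_mean = -MEAN / STD
inv_std = 1.0 / STD

pixels = [0.0, 0.25, 0.75, 1.0]
restored = [
    normalize(normalize(p, MEAN, STD), inv_mean, inv_std)
    for p in pixels
]

# Largest round-trip error across the sample pixels
max_error = max(abs(r - p) for r, p in zip(restored, pixels))
print(max_error < 1e-9)  # True
```

The clipping step in plot_images is still needed for reconstructed images, because the evil model's output is not guaranteed to stay within the normalized range of valid pixels.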

Then define the function to generate images.

def attack(evil_model, victim_model, dataset):
  images = []

  for i in range(6):
    actual_image, _ = dataset[i]

    with torch.no_grad():
      victim_output = victim_model.mobile_stage(actual_image.unsqueeze(0))
      reconstructed_image = evil_model(victim_output).squeeze(0)

    images.append(actual_image)
    images.append(reconstructed_image)

  plot_images(images)

Now execute this function. The images reconstructed by the evil model should closely resemble the original images fed to the target model.

attack(evil_model, victim_model, test_data)