Tensorflow – Multi-GPU doesn’t work for model(inputs) nor when computing the gradients


When using multiple GPUs to run inference with a model (e.g. through its call method, model(inputs)) and to calculate gradients with respect to its inputs, the machine uses only one GPU and leaves the rest idle.

For example, in the code snippet below:

import tensorflow as tf
import numpy as np
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

# Make the tf-data
path_filename_records = 'your_path_to_records'
bs = 128

# parse_record is your own TFRecord parsing function (not shown here)
dataset = tf.data.TFRecordDataset(path_filename_records)
dataset = (dataset
           .map(parse_record, num_parallel_calls=tf.data.experimental.AUTOTUNE)
           .batch(bs))

# Load model trained using MirroredStrategy
path_to_resnet = 'your_path_to_resnet'
mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():
    resnet50 = tf.keras.models.load_model(path_to_resnet)

for pre_images, true_label in dataset:
    with tf.GradientTape() as tape:
        # pre_images is a plain tensor, not a variable, so watch it explicitly
        tape.watch(pre_images)
        outputs = resnet50(pre_images)
    # compute the gradient after leaving the recording context
    grads = tape.gradient(outputs, pre_images)

Only one GPU is used; you can see this by monitoring the GPUs with nvidia-smi. I don’t know whether this is expected, i.e. whether neither model(inputs) nor tape.gradient supports multiple GPUs. If it is, that’s a big problem: with a large dataset, calculating the gradients with respect to the inputs (e.g. for interpretability purposes) can take days on a single GPU.
Another thing I tried was model.predict(), but it can’t be used with tf.GradientTape.

What I’ve tried so far (none of it worked)

  1. Put all the code inside the MirroredStrategy scope.
  2. Used different GPUs: I’ve tried A100, A6000 and RTX5000. I also changed the number of graphics cards and varied the batch size.
  3. Specified a list of GPUs, for instance, strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1']).
  4. Used strategy = tf.distribute.MirroredStrategy(cross_device_ops=tf.distribute.HierarchicalCopyAllReduce()), as @Kaveh suggested.

How do I know that only one GPU is working?

I ran watch -n 1 nvidia-smi in a terminal and saw that only one GPU was at 100% utilization while the rest stayed at 0%.
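Besides nvidia-smi, it can help to confirm from inside Python that TensorFlow sees every GPU and that MirroredStrategy created one replica per GPU. A minimal sketch, assuming TensorFlow 2.x:

```python
import tensorflow as tf

# GPUs TensorFlow can see (after CUDA_VISIBLE_DEVICES has been applied).
gpus = tf.config.list_physical_devices("GPU")
print("Physical GPUs:", [gpu.name for gpu in gpus])

# With no device list, MirroredStrategy uses every visible GPU,
# or a single CPU replica if no GPU is found.
strategy = tf.distribute.MirroredStrategy()
print("Replicas in sync:", strategy.num_replicas_in_sync)
```

If this reports two replicas while nvidia-smi shows only one busy GPU, TensorFlow can see both devices; the idle GPU means the model call is not being dispatched to every replica.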

Working Example

You can find a working example with a CNN trained on the dogs_vs_cats dataset below. You won’t need to download the dataset manually, since I used the tfds version, nor to train a model.

Notebook: Working Example.ipynb

Saved Model:


This is expected: any code outside mirrored_strategy.run() runs on a single GPU (usually the first one, GPU:0). Also, since you want the gradients returned from all replicas, you need mirrored_strategy.gather() as well.

Besides this, you must create a distributed dataset with mirrored_strategy.experimental_distribute_dataset. A distributed dataset splits each global batch evenly across the replicas. An example covering these points is included below.
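To make the split and gather steps concrete, here is a pure-Python simulation. It is not TensorFlow code: split_global_batch, per_replica_step and gather are hypothetical helpers that only mimic what experimental_distribute_dataset, strategy.run and strategy.gather do. With bs = 128 and two GPUs, each replica receives 64 images per step; the sketch scales that down to 8 items and 2 replicas.

```python
# Pure-Python simulation, NOT TensorFlow API: split one global batch across
# replicas, run a per-replica step on each shard, then gather the results.

def split_global_batch(batch, num_replicas):
    """Split `batch` into `num_replicas` contiguous shards of near-equal size.

    Assumption: leftover items go to the first shards; TensorFlow's exact
    handling of uneven batches may differ.
    """
    base, extra = divmod(len(batch), num_replicas)
    shards, start = [], 0
    for r in range(num_replicas):
        size = base + (1 if r < extra else 0)
        shards.append(batch[start:start + size])
        start += size
    return shards


def per_replica_step(shard):
    """Stand-in for step_fn: any per-example computation (here, squaring)."""
    return [x * x for x in shard]


def gather(per_replica_results):
    """Like gathering along axis 0: concatenate shard results in replica order."""
    return [y for shard in per_replica_results for y in shard]


global_batch = list(range(8))          # a scaled-down "bs = 128"
shards = split_global_batch(global_batch, num_replicas=2)
print(shards)                          # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(gather([per_replica_step(s) for s in shards]))
# [0, 1, 4, 9, 16, 25, 36, 49]
```

Each replica only ever sees its own shard, which is why the per-replica results have to be gathered before they can be used as one batch again.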

model.fit(), model.predict(), etc. run in a distributed manner automatically because they already handle everything mentioned above for you.

Example code:

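import tensorflow as tf

mirrored_strategy = tf.distribute.MirroredStrategy()
print(f'using distribution strategy\nnumber of gpus:{mirrored_strategy.num_replicas_in_sync}')


#create distributed dataset (dataset is the batched tf.data pipeline from the question)
ds = mirrored_strategy.experimental_distribute_dataset(dataset)

#make variables mirrored: load the model inside the strategy scope
with mirrored_strategy.scope():
    resnet50 = tf.keras.models.load_model(path_to_resnet)

#per-replica step: gradient of the first output w.r.t. each input image
def step_fn(pre_images):
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(pre_images)  # inputs are plain tensors, so watch them explicitly
        outputs = resnet50(pre_images)[:, 0:1]
    # batch_jacobian returns (batch, 1, H, W, C); drop the size-1 output axis
    return tf.squeeze(tape.batch_jacobian(outputs, pre_images), axis=1)

#define distributed step function using strategy.run and strategy.gather
@tf.function
def distributed_step_fn(pre_images):
    per_replica_grads = mirrored_strategy.run(step_fn, args=(pre_images,))
    return mirrored_strategy.gather(per_replica_grads, 0)  # concatenate along the batch axis

#loop over distributed dataset with distributed_step_fn
for pre_images, true_label in ds:
    result = distributed_step_fn(pre_images)
    print(result.shape)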
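For a version that runs without the saved ResNet-50 or the TFRecords, the same run-and-gather pattern can be exercised end to end. This is a minimal sketch under stated assumptions: the tiny Sequential CNN, the random 16x16 images and PER_REPLICA_BATCH are hypothetical stand-ins, and the global batch is sized as PER_REPLICA_BATCH * num_replicas_in_sync so it always divides evenly. Without GPUs it falls back to a single CPU replica, so the output shape is the same either way.

```python
import numpy as np
import tensorflow as tf

# One replica per visible GPU; a single CPU replica if no GPU is present.
strategy = tf.distribute.MirroredStrategy()

PER_REPLICA_BATCH = 4  # hypothetical per-GPU batch size
GLOBAL_BATCH = PER_REPLICA_BATCH * strategy.num_replicas_in_sync
NUM_IMAGES = 4 * GLOBAL_BATCH  # keeps every batch full and evenly divisible

# Random stand-in data shaped like the question's (images, labels) pairs.
images = np.random.rand(NUM_IMAGES, 16, 16, 3).astype("float32")
labels = np.zeros((NUM_IMAGES,), dtype="int64")
dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(GLOBAL_BATCH)
dist_ds = strategy.experimental_distribute_dataset(dataset)

with strategy.scope():
    # Hypothetical tiny CNN standing in for the saved ResNet-50.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(16, 16, 3)),
        tf.keras.layers.Conv2D(4, 3, activation="relu"),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(2),
    ])


def step_fn(x):
    # Runs once per replica on that replica's shard of the global batch.
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(x)  # x is a plain tensor, so watch it explicitly
        outputs = model(x, training=False)[:, 0:1]
    # (shard, 1, H, W, C) -> (shard, H, W, C)
    return tf.squeeze(tape.batch_jacobian(outputs, x), axis=1)


@tf.function
def distributed_step_fn(x):
    per_replica_grads = strategy.run(step_fn, args=(x,))
    # Concatenate every replica's gradients along the batch axis.
    return strategy.gather(per_replica_grads, axis=0)


saliency = np.concatenate(
    [distributed_step_fn(x).numpy() for x, _ in dist_ds], axis=0)
print(saliency.shape)  # shape: (NUM_IMAGES, 16, 16, 3)
```

With the real model, load it with tf.keras.models.load_model(path_to_resnet) inside the scope and distribute the question’s TFRecord dataset instead. If only per-example input gradients are needed, tape.gradient(tf.reduce_sum(outputs[:, 0]), x) gives the same result as the batch_jacobian call as long as the examples in a batch don’t interact (the case at inference time, where batch norm uses its moving statistics), and it is typically cheaper.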

Answered By – Laplace Ricky

This Answer collected from stackoverflow, is licensed under cc by-sa 2.5 , cc by-sa 3.0 and cc by-sa 4.0
