How can I convert NHWC to NCHW in Python for DeepStream?

Issue

I have a TensorFlow Keras model stored in .pb format, and I am converting it to .onnx format with tf2onnx:

!python -m tf2onnx.convert --saved-model model.pb --output model.onnx 

Now, after converting, I see that my input layer is in NHWC format, and I need to convert it to NCHW. To achieve that, I am using

!python -m tf2onnx.convert --saved-model model.pb --output model_3.onnx --inputs-as-nchw input0:0

but the converted model's input is still NHWC.
I have to consume this model in NVIDIA DeepStream, which only accepts NCHW input.
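One way to confirm whether a converted model's input is NHWC or NCHW is to read the declared input dims from the .onnx file and look at where the channel axis sits. Below is a minimal sketch, assuming the image tensor is the first graph input and has 1, 3 or 4 channels; `guess_image_layout` and `onnx_input_dims` are hypothetical helper names, and the second one needs the `onnx` package.

```python
from typing import List, Sequence, Union

# A dimension is an int when fixed, or a symbolic name / None when unknown.
Dim = Union[int, str, None]

# Channel counts typical of image inputs (an assumption, not a rule).
_COMMON_CHANNELS = (1, 3, 4)


def guess_image_layout(dims: Sequence[Dim]) -> str:
    """Heuristically classify a rank-4 image input shape as 'NHWC', 'NCHW' or 'unknown'."""
    if len(dims) != 4:
        return "unknown"
    channels_last = dims[3] in _COMMON_CHANNELS
    channels_first = dims[1] in _COMMON_CHANNELS
    if channels_last and not channels_first:
        return "NHWC"
    if channels_first and not channels_last:
        return "NCHW"
    return "unknown"  # ambiguous, e.g. [1, 3, 3, 3]


def onnx_input_dims(path: str, index: int = 0) -> List[Dim]:
    """Read the declared dims of one graph input from an .onnx file (requires the onnx package)."""
    import onnx  # imported lazily so guess_image_layout works without onnx installed

    model = onnx.load(path)
    dims = model.graph.input[index].type.tensor_type.shape.dim
    # Each dimension holds either a fixed dim_value or a symbolic dim_param.
    return [d.dim_value if d.HasField("dim_value") else (d.dim_param or None) for d in dims]
```

For example, `guess_image_layout(onnx_input_dims("model_3.onnx"))` should report "NCHW" if the `--inputs-as-nchw` option took effect. The heuristic is only a quick sanity check; a model with an unusual channel count will come back as "unknown".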

I found the following post about transposing the input layer, but unfortunately that did not work either:
Convert between NHWC and NCHW in TensorFlow

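import tensorflow as tf

# input batch in NHWC layout: (batch, height, width, channels)
# note: tf.compat.v1.placeholder only works in graph mode (TF1, or TF2 with eager execution disabled)
images_nhwc = tf.compat.v1.placeholder(tf.float32, [1, 200, 300, 3])
# permute the axes to NCHW: (batch, channels, height, width)
out = tf.transpose(images_nhwc, [0, 3, 1, 2])
print(out.get_shape())
# `model` is the Keras model being converted
model.build(out.get_shape())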
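The `perm` argument `[0, 3, 1, 2]` passed to `tf.transpose` means output axis i is taken from input axis `perm[i]`. The shape arithmetic can be checked without TensorFlow; the sketch below uses hypothetical helpers `permute` and `inverse_perm` that follow the same convention.

```python
from typing import Sequence, Tuple

# The perm used with tf.transpose in the question: NHWC -> NCHW.
NHWC_TO_NCHW = (0, 3, 1, 2)


def permute(shape: Sequence[int], perm: Sequence[int]) -> Tuple[int, ...]:
    """Shape that transposing with `perm` produces: output axis i is input axis perm[i]."""
    return tuple(shape[p] for p in perm)


def inverse_perm(perm: Sequence[int]) -> Tuple[int, ...]:
    """Permutation that undoes `perm` (for example NCHW back to NHWC)."""
    inv = [0] * len(perm)
    for i, p in enumerate(perm):
        inv[p] = i
    return tuple(inv)
```

With these, `permute((1, 200, 300, 3), NHWC_TO_NCHW)` gives `(1, 3, 200, 300)`, and `inverse_perm(NHWC_TO_NCHW)` gives `(0, 2, 3, 1)`, the perm for going back from NCHW to NHWC.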

I would appreciate it if someone could explain how to convert the model input from NHWC to NCHW.
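At the data level, converting NHWC to NCHW only reorders elements. In a row-major buffer, element (n, h, w, c) sits at offset ((n*H + h)*W + w)*C + c in NHWC and at ((n*C + c)*H + h)*W + w in NCHW. The pure-Python sketch below (hypothetical helper `nhwc_to_nchw_flat`) makes that mapping explicit; in practice `tf.transpose` or `numpy.transpose` does the same reordering much faster, and for DeepStream the point is that the model's own input must be NCHW.

```python
from typing import List, Sequence


def nhwc_to_nchw_flat(data: Sequence[float], n: int, h: int, w: int, c: int) -> List[float]:
    """Reorder a flat row-major NHWC buffer into a flat row-major NCHW buffer."""
    if len(data) != n * h * w * c:
        raise ValueError("buffer size does not match n*h*w*c")
    out: List[float] = [0.0] * len(data)
    for b in range(n):
        for y in range(h):
            for x in range(w):
                for ch in range(c):
                    src = ((b * h + y) * w + x) * c + ch  # offset in NHWC order
                    dst = ((b * c + ch) * h + y) * w + x  # offset in NCHW order
                    out[dst] = data[src]
    return out
```

For a 1x2x2 RGB image stored pixel by pixel as `[r0, g0, b0, r1, g1, b1, ...]`, the result holds all red values first, then all green, then all blue.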

Solution

I found the solution.
I needed the latest code for tf2onnx.convert.from_keras, so I installed tf2onnx from the main branch of the repository:

!pip install --force-reinstall git+https://github.com/onnx/tensorflow-onnx.git@main
!pip show tf2onnx
!pip freeze | grep tf2onnx
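The same version check can also be done from Python inside the notebook. Below is a minimal sketch using the standard-library `importlib.metadata` (Python 3.8+); `installed_version` is a hypothetical helper name.

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version of `package` from its metadata, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Prints the tf2onnx version recorded in the installed package metadata.
    print("tf2onnx:", installed_version("tf2onnx"))
```

In a notebook, modules that were already imported stay cached in `sys.modules`, so a kernel restart is typically needed before `import tf2onnx` picks up the reinstalled code.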

Once that was done, I was able to use the latest functionality, corresponding to the code at
https://github.com/onnx/tensorflow-onnx/tree/e896723e410a59a600d1a73657f9965a3cbf2c3b

Below is the code I used to convert my model from .pb to .onnx while changing the input layout from NHWC to NCHW.

import tf2onnx

# `model` is the Keras model loaded from the .pb / SavedModel
# give the list of *inputs* which should be converted and returned *as nchw*
_INPUT = model.input.name

model_proto, external_tensor_storage = tf2onnx.convert.from_keras(model, inputs_as_nchw=[_INPUT])
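`from_keras` returns the converted model in memory as `model_proto`; to hand it to DeepStream it has to be written to a .onnx file. A minimal sketch, assuming only protobuf's standard `SerializeToString()` method (`onnx.save(model_proto, path)` would be an equivalent alternative); `save_onnx` is a hypothetical helper name.

```python
from pathlib import Path
from typing import Protocol


class SerializableProto(Protocol):
    """Anything with protobuf's SerializeToString(), such as an ONNX ModelProto."""

    def SerializeToString(self) -> bytes: ...


def save_onnx(model_proto: SerializableProto, path: str) -> int:
    """Write the serialized model to `path` and return the number of bytes written."""
    data = model_proto.SerializeToString()
    Path(path).write_bytes(data)
    return len(data)
```

Usage would look like `save_onnx(model_proto, "model_nchw.onnx")`, where the file name is only an example. This does not cover models large enough to need `external_tensor_storage`.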

The biggest catch with from_keras was that inputs_as_nchw expects a list of input names ([_INPUT]), not a single name; I was able to find this information in the tf2onnx test cases.
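Because `inputs_as_nchw` expects a list, a small guard can wrap a single name before it is passed in. A likely failure mode with a bare string is that it gets iterated character by character, so no input name matches and the layout silently stays NHWC (an assumption about tf2onnx internals, not something verified here). `as_name_list` is a hypothetical helper name.

```python
from typing import Iterable, List, Union


def as_name_list(names: Union[str, Iterable[str]]) -> List[str]:
    """Wrap a single tensor name in a list; pass an iterable of names through as a list.

    Iterating a plain str yields its characters, so "input_1" must not be
    treated as an iterable of names.
    """
    if isinstance(names, str):
        return [names]
    return list(names)
```

It would be used as `tf2onnx.convert.from_keras(model, inputs_as_nchw=as_name_list(model.input.name))`, which behaves the same as passing `[_INPUT]` directly.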

Answered By – Sovik Gupta

This Answer collected from stackoverflow, is licensed under cc by-sa 2.5 , cc by-sa 3.0 and cc by-sa 4.0
