Issue
I have a matrix (of vectors) X with shape [3,4], and I want to calculate the dot product between each pair of vectors: (X[1].X[1]), (X[1].X[2]), etc.
I saw some cosine similarity code where they use
tf.reduce_sum(tf.multiply(X, X), axis=1)
to calculate the dot product between the vectors in a matrix of vectors. However, this only calculates the dot product between (X[i], X[i]).
I used tf.matmul(X, X, transpose_b=True), which calculates the dot product between every pair of vectors, but I am still confused about why tf.multiply didn't do this. I think the problem is with my code.
The code is:
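import tensorflow as tf  # TensorFlow 1.x API (uses tf.Session)

data = [[1.0, 2.0, 4.0, 5.0], [0.0, 6.0, 7.0, 8.0], [8.0, 1.0, 1.0, 1.0]]
X = tf.constant(data)
# Matrix product X . X^T, shape [3, 3]
matResult = tf.matmul(X, X, transpose_b=True)
# Element-wise square of X, then summed along each row, shape [3]
multiplyResult = tf.reduce_sum(tf.multiply(X, X), axis=1)
with tf.Session() as sess:
    print('matResult')
    print(sess.run([matResult]))
    print()
    print('multiplyResult')
    print(sess.run([multiplyResult]))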
The output is:
matResult
[array([[ 46.,  80.,  19.],
       [ 80., 149.,  21.],
       [ 19.,  21.,  67.]], dtype=float32)]

multiplyResult
[array([ 46., 149.,  67.], dtype=float32)]
I would appreciate any advice.
Solution
tf.multiply(X, Y) or the * operator does element-wise multiplication, so that
[[1 2]     [[1 3]     [[1 6]
 [3 4]]  .  [2 1]]  =  [6 4]]
whereas tf.matmul does matrix multiplication, so that
[[1 0]     [[1 3]     [[1 3]
 [0 1]]  .  [2 1]]  =  [2 1]]
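The two small examples above can be checked numerically. This is a minimal sketch that uses NumPy instead of TensorFlow (an assumption made to keep it short; for 2-D inputs, np.multiply and np.matmul are element-wise and matrix products respectively, just like tf.multiply and tf.matmul). The variable names are illustrative only.

```python
import numpy as np

# NumPy stand-ins for the matrices in the two examples (names are illustrative).
A = np.array([[1, 2], [3, 4]])
B = np.array([[1, 3], [2, 1]])
identity = np.eye(2, dtype=int)

# Element-wise product: entry (i, j) is A[i, j] * B[i, j]. Same as A * B.
elementwise = np.multiply(A, B)

# Matrix product: entry (i, j) is row i of the left matrix dotted with
# column j of the right matrix. Same as identity @ B.
matrix_product = np.matmul(identity, B)

print(elementwise.tolist())     # [[1, 6], [6, 4]]
print(matrix_product.tolist())  # [[1, 3], [2, 1]]
```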
Using tf.matmul(X, X, transpose_b=True) means that you are calculating X . X^T, where ^T indicates the transpose of the matrix and . is matrix multiplication.
tf.reduce_sum(_, axis=1) takes the sum along axis 1 (counting from 0), which means you are summing each row:
tf.reduce_sum([[a, b], [c, d]], axis=1) = [a+b, c+d]
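To make the axis convention concrete, here is a small sketch with NumPy's np.sum standing in for tf.reduce_sum (an assumption for brevity; both use the same axis numbering). The values 1 to 4 play the role of a, b, c, d.

```python
import numpy as np

# [[a, b], [c, d]] with a=1, b=2, c=3, d=4 (illustrative values).
M = np.array([[1, 2], [3, 4]])

row_sums = np.sum(M, axis=1)  # collapses axis 1: [a+b, c+d]
col_sums = np.sum(M, axis=0)  # collapses axis 0: [a+c, b+d]

print(row_sums.tolist())  # [3, 7]
print(col_sums.tolist())  # [4, 6]
```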
This means that:
tf.reduce_sum(tf.multiply(X, X), axis=1) = [X[1].X[1], ..., X[n].X[n]]
so that is the one you want if you only want the squared norm of each row. On the other hand
tf.matmul(X, X, transpose_b=True) = [[X[1].X[1], X[1].X[2], ..., X[1].X[n]],
                                     [X[2].X[1], ..., X[2].X[n]],
                                     ...
                                     [X[n].X[1], ..., X[n].X[n]]]
so that is what you need if you want the similarity between all pairs of rows.
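The relationship between the two results can be sketched with the data from the question. This again uses NumPy rather than TensorFlow, as an assumption for brevity. It checks that the row-wise sums are the diagonal of the pairwise matrix, and shows one plausible way cosine similarity code could combine the two: the row-wise sums give the squared norms used to normalize the pairwise dot products. The names gram, self_dots, and cosine are illustrative.

```python
import numpy as np

# Same data as in the question.
X = np.array([[1.0, 2.0, 4.0, 5.0],
              [0.0, 6.0, 7.0, 8.0],
              [8.0, 1.0, 1.0, 1.0]])

gram = X @ X.T                     # all pairwise dot products, shape (3, 3)
self_dots = np.sum(X * X, axis=1)  # each row dotted with itself, shape (3,)

# The row-wise result is exactly the diagonal of the pairwise result.
print(np.allclose(np.diag(gram), self_dots))  # True

# Cosine similarity: divide each dot product by the product of the two norms.
norms = np.sqrt(self_dots)
cosine = gram / np.outer(norms, norms)
print(np.round(cosine, 3))  # diagonal entries are 1.0
```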
Answered By – patapouf_ai
This Answer collected from stackoverflow, is licensed under cc by-sa 2.5 , cc by-sa 3.0 and cc by-sa 4.0