In [1]:
import random

In [2]:
# A random stand-in for the RSA modulus. getrandbits(1024) leaves the top bit
# clear half the time, so force it: n then has exactly 1024 bits (128 bytes),
# hence n >= 2**1023 > 3B with B = 2**1008, and every value accepted by the
# padding check (2B <= y < 3B) is a proper residue mod n.
n = random.getrandbits(1024) | (1 << 1023)

In [3]:
# The hidden message we wish to recover. It is drawn uniformly from [2B, 3B)
# with B = 2**1008, so it already passes the padding check.
# (2 << 1008 == 2 * 2**1008 == 2B; B itself is only defined in the next cell.)
m = (2 << 1008) + random.getrandbits(1008)

In [4]:
# B = 2**1008 = 2**(8*(128-2)). For a 128-byte n, a value y passes the
# simplified padding check iff 2B <= y < 3B, i.e. its two leading bytes are 00 02.
B = 1 << 1008
B2, B3 = 2 * B, 3 * B
# Running count of padding-oracle queries (incremented by padCheck).
ocalls = 0

In [5]:
def ceildiv(x,y):
    """Return ceil(x / y), computed with exact integer arithmetic (no float rounding)."""
    # Negate, floor-divide, negate back: -floor(-x/y) == ceil(x/y).
    return -(-x // y)
def floordiv(x,y):
    """Return floor(x / y), computed with exact integer arithmetic (no float rounding)."""
    quotient, _remainder = divmod(x, y)
    return quotient

def padCheck(y):
    """Simulated padding oracle: report whether y lies in the conforming range.

    Every call increments the global counter ``ocalls`` so the cost of the
    attack can be reported.

    Args:
        y: candidate plaintext integer (in the attack, m*s mod n).

    Returns:
        True iff 2B <= y < 3B, i.e. the two leading bytes of y are 00 02 when
        B = 2**1008 and y is written as 128 big-endian bytes.
    """
    global ocalls
    ocalls += 1
    return B2 <= y < B3
# The usual PKCS#1 padding check does more (there must be at least 8 padding
# bytes, etc.). This simplified version only checks the leading bytes, which
# also increases the likelihood of hitting a properly padded value.


## Finding an $s$

The function `findS(smin)` tries $s = s_{\min}+1, s_{\min}+2, \dots$ in turn until it finds one for which the padding oracle `padCheck` accepts $ms \bmod n$.

In [6]:
def findS(smin):
    """Return the smallest s > smin for which the oracle accepts m*s mod n.

    Args:
        smin: exclusive lower bound for the search.

    Returns:
        The first integer s > smin with padCheck((m * s) % n) true. Loops until
        such an s is found.
    """
    # m and n are only read here, so no `global` declaration is needed.
    si = smin + 1
    while not padCheck((m * si) % n):
        si += 1
    return si


## Figuring out the range of $r$'s

Once we have found an $s$, and if we are assuming that $a \leq m \leq b$, then we know that $2B \leq ms - rn \leq 3B-1$ for some integer $r$ (here $ms$ is the integer product, before reduction mod $n$). This gives the bound $$\frac{bs - 2B}{n} \geq \frac{ms - 2B}{n} \geq r \geq \frac{ms-3B+1}{n} \geq \frac{as - 3B+1}{n}.$$ Since $r$ is an integer, the function `findrRanges(s,a,b)` returns the smallest and largest values it can take: $\left\lceil \frac{as-3B+1}{n} \right\rceil$ and $\left\lfloor \frac{bs-2B}{n} \right\rfloor$.

In [7]:
def findrRanges(s,a,b):
    """Return (rmin, rmax), the inclusive range of integers r compatible with a <= m <= b.

    From 2B <= m*s - r*n <= 3B - 1 and a <= m <= b:
        rmin = ceil((a*s - 3B + 1) / n),  rmax = floor((b*s - 2B) / n).
    """
    lower_numerator = a * s - B3 + 1
    upper_numerator = b * s - B2
    return ceildiv(lower_numerator, n), floordiv(upper_numerator, n)


## An optimisation in case of a single interval

In Bleichenbacher's original paper, he performs an optimisation in the setting when we have a single interval $[a, b]$ that is not the initial interval $[2B, 3B-1]$. In this case, with $s$ the previous multiplier, you choose the smallest integer $r \geq \frac{2(bs - 2B)}{n}$ and look for an $s$ in the range $$\frac{2B + rn}{b} \leq s \leq \frac{3B-1 +rn}{a}.$$ If no $s$ in this range is accepted by the oracle, increase $r$ by one and try again. Use the $s$ found and proceed.

The idea is that this roughly doubles $s$, which results in a new interval roughly half as wide as the previous one. You can see the details in Bleichenbacher's paper.

In [8]:
def findS_opt(s,a,b):
    """Find the next multiplier when exactly one interval [a, b] containing m remains.

    Starting from the smallest integer r >= 2*(b*s - 2B)/n, try for each r in
    turn every s' with (2B + r*n)/b <= s' <= (3B - 1 + r*n)/a, and return the
    first one for which the oracle accepts m*s' mod n.

    Args:
        s: the previous multiplier.
        a, b: inclusive bounds of the single remaining interval for m.

    Returns:
        The first accepted s'. Loops until one is found.
    """
    # ceildiv, not floordiv: r must satisfy r >= 2*(b*s - 2B)/n, which is what
    # roughly doubles s (and so roughly halves the interval) at each step.
    r = ceildiv(2 * (b * s - B2), n)
    while True:
        s_lo = ceildiv(B2 + r * n, b)
        s_hi = floordiv(B3 - 1 + r * n, a)
        for si in range(s_lo, s_hi + 1):
            if padCheck((si * m) % n):
                return si
        # No accepted s for this r (the range may even be empty): try the next r.
        r += 1


## The attack

If we already know that $a \leq m \leq b$ and that $2B \leq ms-rn \leq 3B-1$, then we get that $$\frac{2B + rn}{s} \leq m \leq \frac{3B + rn -1}{s},$$ and we can intersect this with the old interval $[a,b]$. The code below repeatedly finds an $s$ and updates the set of intervals, until it ends up with a single interval containing exactly one value, which must be $m$.

In [9]:
def Attack():
    """Recover the hidden message m using only the padding oracle.

    Keeps a set of closed intervals (a, b) known to contain m. Each round it
    finds a multiplier s whose product m*s mod n the oracle accepts, then
    intersects every interval with the ranges implied by
    2B <= m*s - r*n <= 3B - 1. Stops when a single one-value interval is left
    and prints the number of oracle calls used.

    Returns:
        The recovered integer.

    Raises:
        RuntimeError: if no interval survives an update (should not happen
        with a consistent oracle).
    """
    global ocalls
    ocalls = 0
    # Initially all we know is that m itself is conforming: 2B <= m <= 3B - 1.
    newM = {(B2, B3 - 1)}
    # The first search starts at ceil(n / 3B). For 2 <= s < n/(3B) we have
    # m*s < n, so m*s >= 4B is never reduced mod n and can never be accepted.
    si = ceildiv(n, B3) - 1
    while True:
        # First find an si.
        if len(newM) > 1 or (B2, B3 - 1) in newM:
            si = findS(si)
        else:
            # Exactly one narrowed interval: use the optimised search. Peek at
            # the interval without removing it -- the update below still needs it.
            (a, b) = next(iter(newM))
            si = findS_opt(si, a, b)

        # Update the intervals.
        newMM = set()
        for (a, b) in newM:
            (r1, r2) = findrRanges(si, a, b)
            for r in range(r1, r2 + 1):
                newa = max(a, ceildiv(B2 + r * n, si))
                newb = min(b, floordiv(B3 - 1 + r * n, si))
                if newa <= newb:
                    newMM.add((newa, newb))
        if not newMM:
            raise RuntimeError("Something went wrong!")
        newM = newMM

        if len(newM) == 1:
            (a, b) = next(iter(newM))
            if a == b:
                print("Oracle calls:", ocalls)
                return a

In [10]:
%%time
guess = Attack()
if m==guess:
print("Found m: ",m)

Oracle calls: 4506
Found m:  5971331115703357078891961541342638949664044605888442834663177475828153859701294020291819323895945632313357163311109747139028146977957416417384602481528503116951057112305786705469189482505678846379821056382535831302517472639355811058140760853748105564413984688674138740815451264447663254207534472088145088
CPU times: user 69.5 ms, sys: 3.83 ms, total: 73.3 ms
Wall time: 206 ms