In [1]:
import random
In [2]:
# A random modulus with exactly 1024 bits: getrandbits(1024) alone returns a
# shorter number about half the time, so force the top bit. This also keeps
# n far above 3B (and hence above m), which the interval arithmetic relies on.
n = random.getrandbits(1024) | (1 << 1023)
In [3]:
# Hidden message: 2B + (random 1008-bit value), with B = 2**1008, so that
# 2B <= m < 3B, i.e. the top 16 bits of its 1024-bit encoding are 0x0002.
m = (2 << 1008) + random.getrandbits(1008)
#This is the hidden message we wish to find.
In [4]:
# Bleichenbacher's bound B = 2^(8(k-2)); for a 1024-bit (k = 128 byte)
# modulus that is 2^1008. A "conforming" value y satisfies 2B <= y <= 3B-1.
B = 1 << 1008
B2, B3 = 2 * B, 3 * B
In [5]:
def ceildiv(x,y):
    """Return ceil(x / y) using exact integer arithmetic (no float rounding)."""
    # ceil(x / y) == -floor(-x / y); Python's // always floors.
    return -((-x) // y)
def floordiv(x,y):
    """Return floor(x / y) using exact integer arithmetic."""
    quotient, _remainder = divmod(x, y)
    return quotient

def padCheck(y):
    """Simplified padding oracle: is y a "conforming" value?

    Returns True iff 2B <= y <= 3B-1 (i.e. B2 <= y < B3), else False.
    Every call is counted in the global ``ocalls`` so the attack can report
    how many oracle queries it needed.
    """
    global ocalls
    # Create the counter on first use so the oracle also works when it is
    # queried before Attack() has reset it.
    ocalls = globals().get("ocalls", 0) + 1
    return B2 <= y < B3
#The real PKCS #1 v1.5 check does more (e.g. the padding string must be at least 8 non-zero bytes followed by a 0x00 separator).
#This simplified version only requires 2B <= y < 3B, which also increases the likelihood of hitting a "properly padded" value.

Finding an s

The function findS(smin) tries successive values of s strictly above smin (smin+1, smin+2, …) until it finds one for which padCheck((m*s) mod n) accepts, and returns that s.

In [6]:
def findS(smin):
    """Return the smallest s > smin for which the oracle accepts (m*s) mod n.

    Candidates smin+1, smin+2, ... are tried in order; each try costs one
    padCheck query. Searching strictly above smin lets the caller pass the
    previously found s and get the next conforming one.
    """
    si = smin
    while True:
        si += 1
        if padCheck((m * si) % n):
            return si

Figuring out the range of r's

Once we have found an $s$, and if we are assuming that $a \leq m \leq b$, then we know that $2B \leq ms - rn \leq 3B-1$ for some positive integer $r$ (namely $r = \lfloor ms/n \rfloor$, so that $ms - rn = ms \bmod n$). Rearranging gives the bounds $$ \frac{bs - 2B}{n} \geq \frac{ms - 2B}{n} \geq r \geq \frac{ms-3B+1}{n} \geq \frac{as - 3B+1}{n}. $$ Since $r$ is an integer, the function findrRanges(s,a,b) returns $r_{\min} = \left\lceil \frac{as - 3B + 1}{n} \right\rceil$ and $r_{\max} = \left\lfloor \frac{bs - 2B}{n} \right\rfloor$.

In [7]:
def findrRanges(s,a,b):
    """Return the inclusive integer range (rmin, rmax) of feasible r.

    Given a conforming s and the assumption a <= m <= b, any r with
    2B <= m*s - r*n <= 3B-1 satisfies (a*s - 3B + 1)/n <= r <= (b*s - 2B)/n;
    the bounds are rounded inward to integers.
    """
    rmin = ceildiv(a * s - B3 + 1, n)
    rmax = floordiv(b * s - B2, n)
    return (rmin, rmax)

An optimisation in case of a single interval

In his original paper, Bleichenbacher applies an optimisation when only a single interval $[a, b]$ remains and it is not the initial interval $[2B, 3B-1]$. Starting from the smallest integer $r \geq \frac{2(bs - 2B)}{n}$, look for an $s$ in the range $$ \frac{2B + rn}{b} \leq s \leq \frac{3B-1 +rn}{a}; $$ if no $s$ in that range passes the padding check, increase $r$ by one and try again. Use the first $s$ found and proceed.

The idea is that this results in a new interval that's no more than half the original interval. You can see the details in Bleichenbacher's paper.

In [8]:
def findS_opt(s,a,b):
    """Optimised search for the next conforming s when one interval remains.

    Starting from the smallest integer r >= 2*(b*s - 2B)/n, scan the s values
    in [ceil((2B + r*n)/b), floor((3B - 1 + r*n)/a)] and return the first one
    the oracle accepts; if none does, move on to r + 1.

    Args:
        s: the previously found conforming multiplier.
        a, b: inclusive bounds of the single remaining interval for m.
    """
    r = ceildiv(2 * (b * s - B2), n)
    while True:
        # Only these s can satisfy 2B <= m*s - r*n <= 3B-1 for an m in [a, b].
        lo = ceildiv(B2 + r * n, b)
        hi = floordiv(B3 - 1 + r * n, a)
        for si in range(lo, hi + 1):
            if padCheck((si * m) % n):
                return si
        r += 1

The attack

If we already know that $a \leq m \leq b$ and that $2B \leq ms-rn \leq 3B-1$, then $$ \frac{2B + rn}{s} \leq m \leq \frac{3B - 1 + rn}{s}, $$ and we can intersect this with the old interval $[a,b]$, giving one candidate sub-interval per feasible $r$. The code below repeatedly finds an $s$ and updates the set of intervals, until only a single interval containing exactly one value ($a = b$) remains; that value is $m$.

In [9]:
def Attack():
    """Recover the hidden message ``m`` using only ``padCheck`` queries.

    Keeps a set of inclusive candidate intervals ``(a, b)`` for ``m``,
    starting from ``[2B, 3B-1]``. Each round finds a new conforming ``s``
    and intersects every interval with the bounds it implies on ``m``. Stops
    when a single one-value interval remains, printing the number of oracle
    calls used.

    Returns:
        The recovered integer, or ``None`` (after printing a message) if every
        candidate interval was eliminated, which indicates a bug in the bounds.
    """
    global ocalls
    ocalls = 0  # reset the oracle-query counter reported at the end
    # findS tries values strictly above si, so the first candidate is ceil(n / 3B).
    si = ceildiv(n, B3) - 1
    newM = {(B2, B3 - 1)}
    while True:
        # First find an si
        if len(newM) > 1 or (B2, B3 - 1) in newM:
            si = findS(si)
        else:
            # Exactly one interval, no longer the initial one: use the
            # optimised search. Peek at the interval instead of popping it,
            # so it is still there for the update step below.
            a, b = next(iter(newM))
            si = findS_opt(si, a, b)
        # Update the intervals: one candidate sub-interval per feasible r.
        newMM = set()
        for a, b in newM:
            r1, r2 = findrRanges(si, a, b)
            for r in range(r1, r2 + 1):
                newa = max(a, ceildiv(B2 + r * n, si))
                newb = min(b, floordiv(B3 - 1 + r * n, si))
                if newa <= newb:
                    newMM.add((newa, newb))
        if not newMM:
            # m must lie in some interval; an empty set means the bounds are
            # wrong, and looping again would never terminate.
            print("Something went wrong!")
            return None
        newM = newMM
        if len(newM) == 1:
            a, b = next(iter(newM))
            if a == b:
                print("Oracle calls:", ocalls)
                return a
In [10]:
# Run the attack and check the recovered value against the hidden message.
guess = Attack()
if m == guess:
    print("Found m: ", m)
else:
    print("Attack failed; recovered:", guess)
Oracle calls: 4506
Found m:  5971331115703357078891961541342638949664044605888442834663177475828153859701294020291819323895945632313357163311109747139028146977957416417384602481528503116951057112305786705469189482505678846379821056382535831302517472639355811058140760853748105564413984688674138740815451264447663254207534472088145088
CPU times: user 69.5 ms, sys: 3.83 ms, total: 73.3 ms
Wall time: 206 ms