from random import *
import numpy as np
from math import *
import gc

# Weight given to unreachable trellis states. It must exceed any achievable
# path cost (a sum of rho values), so a true infinity is used rather than a
# small finite number that accumulated costs could reach.
infinity = float('inf')



# forward part


def dec(ch, n):
    """Convert a list of bits into an integer, most significant bit first.

    ch[i] carries weight 2**(n - i - 1), so ch is read as an n-bit
    big-endian number. Only entries equal to 1 contribute; anything else
    counts as 0.

    Args:
        ch: sequence of bits.
        n: bit width used to weight the positions (normally len(ch)).

    Returns:
        The integer value of the bit list.
    """
    return sum(2 ** (n - i - 1) for i, bit in enumerate(ch) if bit == 1)


def bin(elem, n):
    """Convert a non-negative integer into its n-bit binary representation.

    The result is a list of n bits, most significant bit first, so that
    dec(bin(e, n), n) == e for every 0 <= e < 2**n.

    Args:
        elem: non-negative integer (an integral float is accepted; its bits
            are then floats such as 1.0).
        n: number of bits in the result.

    Returns:
        List of n bits, MSB first.

    Raises:
        ValueError: if elem is negative or needs more than n bits. Such values
            used to wrap around into negative list indices and silently
            produce a wrong bit list.
    """
    if not 0 <= elem < 2 ** n:
        raise ValueError("%r does not fit in %d bits" % (elem, n))
    res = [0] * n
    pos = n - 1
    # Peel off the least significant bit each time and fill from the right.
    while elem != 0:
        elem, res[pos] = divmod(elem, 2)
        pos -= 1
    return res



def xorb(a, b):
    """Exclusive-or of two bits: 1 when they differ, 0 when they are equal."""
    differ = a != b
    return int(differ)

def xor(e1, e2, h):
    """Bitwise XOR of two h-bit integers, keeping only the low h bits.

    This is computed directly with ``^`` instead of expanding both operands
    into bit lists. It is called in the inner loop of the trellis, so the
    O(h) list round-trip mattered.

    Args:
        e1, e2: non-negative integers. Integral floats (e.g. 1.0 * column)
            and numpy integers are accepted and converted with int().
        h: bit width of the result.

    Returns:
        (e1 XOR e2) restricted to its h least significant bits, as an int.
    """
    return (int(e1) ^ int(e2)) & ((1 << h) - 1)

def lit(d, key):
    """Return d[(indx, indy)] for the pair ``key``, or 0 when it is absent.

    ``key`` is unpacked exactly as the former tuple parameter
    ``(indx, indy)`` was. That parameter form is Python-2-only syntax
    (PEP 3113). Callers passing a 2-tuple positionally are unaffected.
    """
    indx, indy = key
    return d.get((indx, indy), 0)




def _forward_column(wght, column, x_bit, cost, indx, path):
    """Relax every trellis state across one cover element (one H_hat column).

    Returns the new weight list. Also records path[(indx, k)] = 1 for each
    state k whose cheaper predecessor sets y[indx] = 1.
    """
    newwght = []
    for k in range(len(wght)):
        w0 = wght[k] + x_bit * cost                 # y[indx] = 0: pay if x is 1
        w1 = wght[k ^ column] + (1 - x_bit) * cost  # y[indx] = 1: pay if x is 0
        if w1 < w0:  # ties keep y[indx] = 0
            path[(indx, k)] = 1
        newwght.append(min(w0, w1))
    return newwght


def _forward_prune(wght, bit):
    """Keep the states whose low bit equals ``bit``, shifted down one position.

    The vacated upper half of the state space becomes unreachable.
    """
    half = len(wght) // 2
    return [wght[2 * j + bit] for j in range(half)] + [infinity] * half


def forward(H_hat, x, message, lnm, rho):
    """Forward (Viterbi) pass of the syndrome-trellis code.

    The trellis is walked over the cover bits ``x`` using one column of
    ``H_hat`` per element. The weight of state k is the cheapest total cost
    (sum of rho[i] over the positions where y differs from x) of any partial
    stego vector that reaches syndrome state k. After each block of
    w = len(H_hat) columns, and after the trailing partial block when
    len(x) % w != 0, only the states whose low bit equals the next message
    bit survive, shifted down by one.

    Args:
        H_hat: list of positive column integers. The bit length h of its
            maximum gives 2**h trellis states.
        x: cover bits (expected to be 0/1).
        message: message bits; one is consumed per block.
        lnm: unused; kept for call compatibility.
        rho: per-position costs; at least len(x) entries are read.

    Returns:
        (start, path). ``start`` is the index of the cheapest final state,
        as returned by np.argmin. ``path`` maps (indx, state) -> 1 for every
        state whose best predecessor set y[indx] = 1; a missing key means
        y[indx] = 0.
    """
    h, w = int(max(H_hat)).bit_length(), len(H_hat)
    path = {}
    wght = [infinity] * (2 ** h)
    wght[0] = 0  # the trellis starts in the all-zero syndrome state
    nbblock, reste = divmod(len(x), w)
    # Full blocks use every column; the trailing block uses only the first
    # `reste` columns of H_hat.
    widths = [w] * nbblock + ([reste] if reste else [])
    indx = 0
    for indm, width in enumerate(widths):
        for j in range(width):
            wght = _forward_column(wght, H_hat[j], x[indx], rho[indx],
                                   indx, path)
            indx += 1
        wght = _forward_prune(wght, message[indm])
    start = np.argmin(wght)
    return (start, path)


def backward(start, H_hat, x, message, lnm, path):
    """Trace back through ``path`` to rebuild the stego vector y.

    This mirrors forward(): the same block layout (full blocks of
    w = len(H_hat) columns plus a trailing partial block), walked in
    reverse. Each prune is undone with state = 2 * state + message bit.

    Args:
        start: final state returned by forward().
        H_hat, x, message: the same values that were passed to forward().
        lnm: unused; kept for call compatibility. Message indices follow the
            number of blocks, exactly as forward() consumes them. Starting at
            lnm - 1 desynchronised the traceback whenever len(message)
            differed from the number of blocks.
        path: dict produced by forward().

    Returns:
        List of len(x) ints, the stego bits y.
    """
    w = len(H_hat)
    nbblock, reste = divmod(len(x), w)
    widths = [w] * nbblock + ([reste] if reste else [])
    y = [0] * len(x)
    if not widths:
        return y
    # Undo the prune applied after the last block.
    state = 2 * int(start) + message[len(widths) - 1]
    indx = len(x) - 1
    for block in range(len(widths) - 1, -1, -1):
        for j in reversed(range(widths[block])):
            bit = path.get((indx, state), 0)
            y[indx] = bit
            if bit:
                state ^= H_hat[j]
            indx -= 1
        if block > 0:
            # Undo the prune applied after the previous block.
            state = 2 * state + message[block - 1]
    return y

    
 



def trouve_H_hat(n, m, h):
    """Look up a submatrix H_hat for constraint height h = 7.

    The table is indexed by the integer part of n / m, capped at 9. A ratio
    below 2 falls back to the width-2 entry. The entry for width i holds i
    column integers.

    Args:
        n: cover length.
        m: message length.
        h: constraint height; only 7 is supported.

    Returns:
        List of column integers.

    Raises:
        AssertionError: if h != 7 or n < m.
    """
    assert h == 7
    ratio = float(n) / m
    assert ratio >= 1
    columns_by_width = {
        2: [71, 109],
        3: [95, 101, 121],
        4: [81, 95, 107, 121],
        5: [75, 95, 97, 105, 117],
        6: [73, 83, 95, 103, 109, 123],
        7: [69, 77, 93, 107, 111, 115, 121],
        8: [69, 79, 81, 89, 93, 99, 107, 119],
        9: [69, 79, 81, 89, 93, 99, 107, 119, 125],
    }
    width = min(int(ratio), 9)
    return columns_by_width.get(width, columns_by_width[2])


def stc(x, rho, message, H_hat=None):
    """Embed ``message`` into the cover bits ``x`` with a syndrome-trellis code.

    Args:
        x: cover bits.
        rho: per-position costs; at least len(x) entries are read.
        message: message bits.
        H_hat: optional list of column integers for the submatrix. Defaults
            to [3, 1], the value this function always used before.

    Returns:
        (x_b, y, H_hat_used). ``x_b`` is a list copy of x, ``y`` is the stego
        bits from backward(), and ``H_hat_used`` is the submatrix used.
    """
    lnm = len(message)
    # Called for its input checks (h == 7, len(x) >= len(message)).
    # NOTE(review): its suggested matrix was always overridden by [3, 1];
    # pass H_hat=trouve_H_hat(len(x), len(message), 7) to use it instead.
    trouve_H_hat(len(x), lnm, 7)
    mat = [3, 1] if H_hat is None else list(H_hat)
    x_b = list(x)
    (start, path) = forward(mat, x_b, message, lnm, rho)
    return (x_b, backward(start, mat, x_b, message, lnm, path), mat)





def nbdif(x, y):
    """Return the fraction of positions i < len(y) where x[i] != y[i].

    Only the first len(y) entries of x are compared.

    Returns:
        A float in [0, 1]; 0.0 when y is empty (this used to divide by zero).

    Raises:
        ValueError: if x is shorter than y.
    """
    l = len(y)
    if len(x) < l:
        raise ValueError("x (%d items) is shorter than y (%d items)"
                         % (len(x), l))
    if l == 0:
        return 0.0
    return float(sum(1 for a, b in zip(x, y) if a != b)) / l
        





def prod(H_hat, lnm, y):
    """Compute H . y, where H tiles H_hat down the diagonal.

    H has lnm rows and len(y) columns. Block b occupies columns
    b*w .. b*w + w - 1 (w = len(H_hat)) and contributes to rows b .. b+h-1
    (h = bit length of max(H_hat)). Row b + s receives bit s, counting from
    the least significant bit, of each column integer. Rows at or beyond
    lnm and columns at or beyond len(y) are dropped.

    Returns:
        numpy array of length lnm holding the integer row sums. They are NOT
        reduced mod 2.
    """
    h, w = int(max(H_hat)).bit_length(), len(H_hat)
    n = len(y)
    rows = np.zeros((lnm, n), dtype=int)
    for i in range(lnm):
        # Only the last h blocks up to block i touch row i.
        for b in range(max(i - h + 1, 0), i + 1):
            shift = i - b
            for l, column in enumerate(H_hat):
                k = b * w + l
                if k < n:
                    rows[i, k] = (column >> shift) & 1
    return np.dot(rows, np.array(y))

    
def equiv(x, y):
    """Return True when x and y have the same parity at every position.

    Raises:
        ValueError: if x and y differ in length. An ``assert`` was used
            before, which is stripped under ``python -O``.
    """
    if len(x) != len(y):
        raise ValueError("length mismatch: %d != %d" % (len(x), len(y)))
    return all(a % 2 == b % 2 for a, b in zip(x, y))
        

################
"""
x = [randint(0,1) for _ in xrange(65000)]
rho = [randint(1,9) for _ in xrange(65000)]
message = [randint(0,1) for _ in xrange(26000)]
"""
x = [0, 0, 1, 0, 0, 1]
rho = [1 for _ in xrange(100)]
message = [0, 1, 0, 1]



(x_b,y,H_hat) = stc(x,rho,message)
print "message", message
print "y", y
print "x", x
print "H_hat", H_hat



# x_b est la sous partie de x qui va etre modifiee
# y est le vecteur des bits modifies
# H_hat est la sous-matrice retenue qui est embarquee dans H  

print "avec stc :", nbdif(x_b,y)
print "sans stc :", nbdif(message,x[:len(message)])

#print message
#print x
#print rho
#print y
message2 = [x%2 for x in prod(H_hat,len(message),y)]
print "messag2", message2

#print message2 
print equiv(message,message2)