from random import *
import numpy as np
from math import *

# Sentinel cost for unreachable trellis states. A true floating-point infinity
# is used so it can never collide with a genuine accumulated cost; a finite
# value such as 10000 can be reached by the summed costs of a long cover.
infinity = float("inf")



# forward part


def dec(ch, n):
    """Interpret the bit list ``ch`` as an unsigned integer of width ``n``.

    ``ch[0]`` is the most significant bit: entry ``ch[pos]`` contributes
    ``2**(n - pos - 1)`` when it equals 1, and nothing otherwise.
    """
    return sum(2 ** (n - pos - 1) for pos, bit in enumerate(ch) if bit == 1)


def bin(elem, n):
    """Convert a non-negative integer to a list of ``n`` bits, MSB first.

    Note: this function shadows the built-in ``bin``.

    Args:
        elem: non-negative integer to convert.
        n: number of bits in the result.

    Returns:
        A list of ``n`` ints (0 or 1); index 0 is the most significant bit.

    Raises:
        ValueError: if ``elem`` is negative (the previous repeated-division
            loop never terminated for negatives, since ``-1 // 2 == -1``) or
            does not fit in ``n`` bits (previously it wrapped silently).
    """
    if elem < 0:
        raise ValueError("bin() expects a non-negative integer, got %d" % elem)
    if elem >> n:
        raise ValueError("%d does not fit in %d bits" % (elem, n))
    return [(elem >> (n - 1 - i)) & 1 for i in range(n)]



def xorb(a, b):
    """Exclusive-or of two bits: 1 when they differ, 0 when they match."""
    if a != b:
        return 1
    return 0

def xor(e1, e2, w):
    """Bitwise exclusive-or of two non-negative integers of width ``w``.

    Same result as XOR-ing the ``w``-bit representations position by position,
    computed directly with the ``^`` operator instead of round-tripping
    through bit lists.

    Raises:
        ValueError: if either operand is negative or does not fit in ``w``
            bits (previously this hung forever or returned garbage).
    """
    for e in (e1, e2):
        if e < 0 or e >> w:
            raise ValueError("%d is not a %d-bit non-negative integer" % (e, w))
    return e1 ^ e2


def forward(H_hat, x, message, rho=None):
    """Run the forward (Viterbi) pass of syndrome-trellis coding.

    The parity-check matrix H is made of len(message) copies of the submatrix
    H_hat, each copy shifted down by one row. A trellis state is an h-bit
    partial syndrome whose lowest bit is the row of the message bit currently
    being processed (h = bit length of the largest column of H_hat).

    Args:
        H_hat: list of w non-negative integers, at least one non-zero;
            H_hat[j] encodes column j of the submatrix, bit t being row t
            relative to the current block.
        x: cover bits (0/1); len(x) must equal len(H_hat) * len(message).
        message: bits (0/1) to embed as the syndrome.
        rho: optional per-position cost of making y[i] differ from x[i];
            defaults to 1 for every position.

    Returns:
        path: len(x) rows of 2**h entries; path[i][k] is the bit y[i] (0/1)
        chosen on the cheapest way of reaching state k after position i.
        It is meant to be consumed by the traceback (backward()).

    Raises:
        ValueError: if H_hat is empty, has a negative or no non-zero column,
            or the lengths of x, message and rho are inconsistent.
    """
    w = len(H_hat)
    if not H_hat or min(H_hat) < 0 or max(H_hat) < 1:
        raise ValueError("H_hat must hold non-negative columns, one non-zero")
    # Number of rows of H_hat; exact integer arithmetic (no float log).
    h = max(H_hat).bit_length()
    nstates = 2 ** h
    half = 2 ** (h - 1)
    nbblock = len(message)
    if len(x) != w * nbblock:
        raise ValueError("len(x) must equal len(H_hat) * len(message)")
    if rho is None:
        rho = [1] * len(x)
    elif len(rho) != len(x):
        raise ValueError("rho must have the same length as x")

    # True infinity so unreachable states never tie with real costs.
    inf = float("inf")
    path = [[0] * nstates for _ in range(len(x))]
    wght = [inf] * nstates
    wght[0] = 0  # the syndrome starts at zero
    indx = 0

    for indm in range(nbblock):          # for each message bit (block)
        for hcol in H_hat:               # for each column of H_hat
            cost0 = x[indx] * rho[indx]          # cost of choosing y = 0
            cost1 = (1 - x[indx]) * rho[indx]    # cost of choosing y = 1
            row = path[indx]
            newwght = [0] * nstates
            for k in range(nstates):     # for each trellis state
                w0 = wght[k] + cost0
                # Choosing y = 1 toggles the column into the partial syndrome.
                w1 = wght[k ^ hcol] + cost1
                row[k] = 1 if w1 < w0 else 0
                newwght[k] = min(w0, w1)
            wght = newwght
            indx += 1

        # Row indm of the syndrome is now final: keep only states whose lowest
        # bit equals message[indm], shift them down one row, and mark the
        # upper half (fresh, still-unknown top row set to 1) unreachable.
        m = message[indm]
        wght = [wght[2 * j + m] for j in range(half)] + [inf] * (nstates - half)

    return path


def backward(H_hat, x, message, path):
    """Trace back through the decision table built by forward() to recover y.

    Args:
        H_hat: the submatrix columns that were passed to forward().
        x: cover bits; only its length is used here, and it must equal
            len(H_hat) * len(message).
        message: the bits that were passed to forward().
        path: decision table returned by forward(); path[i][state] is the
            stego bit chosen at position i when in that state.

    Returns:
        y: list of len(x) stego bits read from path.

    Raises:
        ValueError: if len(x) != len(H_hat) * len(message).
    """
    w = len(H_hat)
    nbblock = len(message)
    if len(x) != w * nbblock:
        raise ValueError("len(x) must equal len(H_hat) * len(message)")
    y = [0] * len(x)
    # Start from the all-zero final state; before walking each block in
    # reverse, re-insert the message bit that the forward pass pruned on.
    state = 0
    indx = len(x) - 1
    for indm in reversed(range(nbblock)):
        state = 2 * state + message[indm]
        for j in reversed(range(w)):     # each column of H_hat, last first
            y[indx] = path[indx][state]
            # Undo this column's contribution to get the predecessor state.
            state ^= y[indx] * H_hat[j]
            indx -= 1
    return y

    
 

def stc(x, message, H_hat=None):
    """Embed ``message`` into the cover bits ``x`` with a syndrome-trellis code.

    Args:
        x: cover bits (0/1). forward() and backward() consume len(H_hat)
            cover bits per message bit.
        message: bits (0/1) to embed.
        H_hat: submatrix columns encoded as integers; defaults to [3, 2],
            the value previously hard-coded here.

    Returns:
        The stego bit list produced by the traceback (backward()).
    """
    if H_hat is None:
        # TODO: choose H_hat more carefully; [3, 2] is not an optimized
        # submatrix.
        H_hat = [3, 2]
    path = forward(H_hat, x, message)
    return backward(H_hat, x, message, path)



# Demo run: embed a random 25000-bit message into 50000 random cover bits
# (two cover bits per message bit, matching the default two-column H_hat).
x = [randint(0, 1) for _ in range(50000)]
message = [randint(0, 1) for _ in range(25000)]

# Function-call form of print: identical output on Python 2 for a single
# argument, and valid syntax on Python 3.
print(stc(x, message))