#-*-coding:utf-8-*-

#############################################
# Script réalisé dans le cadre du papier
# concernant l'utilisation de plusieurs
# fonctions chaotiques (pas seulement la 
# négation vectorielle) pour le tatouage.
#
# On utilise 3 fonctions différentes, on tatoue
# dans le domaine ondelette.
#############################################
import pywt

from numpy import * 

#from outilsBase import conversion, getBit, setBit
from random import *
from copy import deepcopy
import Image as im
from ImageChops import difference
from suite import *
from attaque import Attaque

def matrice_to_bits(matrice):
    '''
    Renvoie la matrice des écritures binaires de matrice.
    Les coefficients floants deviennent des chaînes (str) de bits.
    '''
    (m,n) = matrice.shape
    retour = []
    for l in range(m):
        ligne = []
        for c in range(n):
            ligne.append(conversion(str(matrice[l,c]),2))
        retour.append(ligne)
    return retour



def matrice_lscs(matrice,lscs):
    '''
    Matrice est une liste de listes, lscs est une liste.

    A partir d'une matrice de coefficients binaires, vus comme
    des chaines de caractères, extrait les bits dont les positions
    sont données par la liste lscs.
    
    Dans la liste lscs, un entier positif signifie devant la virgule,
    un entier négatif signifie derrière. Le premier bit devant la 
    virgule est le bit 1, le premier bit derrière la virgule est le -1.

    Le retour est une liste de bits (entiers).
    '''
    m,n = len(matrice), len(matrice[0])
    retour = []
    for l in range(m):
        for c in range(n):
            #num = str(round(float(matrice[l][c]),2))
            num = matrice[l][c]

            if '.' not in num:
                ent,dec = num.replace('-',''),'0'
            else:
                ent,dec = num.replace('-','').split('.')
            ent,dec = list(ent),list(dec)
            ent.reverse()
            for lsc in lscs:
                if lsc > 0 and len(ent)>=lsc:
                    retour.append(ent[lsc-1])
                elif lsc<0 and len(dec)>=abs(lsc):
                    retour.append(dec[abs(lsc)-1])
                else:
                    retour.append('0')
    return [int(k) for k in retour][:-3]


def embarque(liste,matrice,lscs):
    m,n = len(matrice), len(matrice[0])
    retour = []
    cpt = 0
    for l in range(m):
        for c in range(n):
            if '-' in matrice[l][c]:
                signe = '-'
            else:
                signe = ''
            if '.' not in matrice[l][c]:
                ent,dec = matrice[l][c].replace('-',''),'0'
            else:
                ent,dec = matrice[l][c].replace('-','').split('.')
            ent,dec = list(ent),list(dec)
            ent.reverse()
            maximum = max([abs(k) for k in lscs])
            ent = list(''.join(ent).zfill(maximum+2))
            dec = list(''.join(dec).zfill(maximum+2))
            print dec
            for lsc in lscs:
                if lsc > 0:
                    ent[lsc-1] = str(liste[cpt])
                else:
                    dec[abs(lsc)-1] = str(liste[cpt])
                cpt += 1
            ent.reverse()
            ent = ''.join(ent)
            dec = ''.join(dec)
            print ent+'.'+dec



def f(L):
    assert len(L)%4 == 1
    n = len(L)/4
    retour = [int(not L[k]) for k in range(n)]
    retour.extend([L[k-n] for k in range(n,2*n)])
    retour.extend([int(L[k-2*n])*int(not L[k+1]) for k in range(2*n,4*n)])
    retour.extend([int(not L[2*n])])
    return retour




def f(L):
    retour = [int(not L[0])]
    retour.extend([L[k-1] for k in range(1,len(L))])
    return retour    




def f(x):
    rl=[]
    k = 1
    for el in x:
        if k%2 != 0 :
            rl.append(int(not el))
        else:
            rl.append(0 if sum(x[k-2:k])%2==0 else 1 )
        k+=1
    return rl




def f(i,x):
    r = 0
    if i%2 == 0 :
        r = int(not x[i])
    else:
        r = 0 if sum(x[i-1:i+1])%2==0 else 1 
    return r



def f(i,L):
    return int(not L[i])



def embarque2(liste,matrice,lscs):
    
    m,n = len(matrice), len(matrice[0])
    retour = deepcopy(matrice)
    cpt = 0
    for l in range(m):
        for c in range(n):
            for lsc in lscs:
                try:
                    retour[l,c] = setBit(str(retour[l,c]),lsc, liste[cpt])
                    cpt += 1
                except:
                    pass
    return retour

def PSNR(mat1,mat2):
    (li,co) = mat1.shape
    '''
    Retourne le PSNR entre deux images.
    '''
    from math import log
    densite = 2**8-1
    eqm = 0
    for k in range(li):
        for l in range(co):
            eqm+=(mat1[k][l] - mat2[k][l])**2
    r= float(eqm)/li/co
    if r !=0:
        return 10*log(densite**2/r,10)
    else:
        return "Infini"


    
def diff(mat1,mat2):
    lm=len(mat1)
    cpt = 0
    for k in range(lm):
        if mat1[k] != mat2[k] :
            cpt +=1
            #print mat1[k],mat2[k] 
    return float(100)*cpt/lm

def arrondi_en_entier(mat):
    (l,c) = mat.shape
    return array([[round(mat[i][j]) for j in range(c)]for i in range(l)])
        
        
            


def experience(nom_fichier):
    str_acc = "nom;"+nom_fichier

    initiale_image = im.open(nom_fichier+".png")
    (taillex,tailley) = initiale_image.size
    initiale_matrice =  array(initiale_image.getdata()).reshape((taillex,tailley))

    #recup 
    hote_dct = Dct(nom_fichier+".png",idx_dct)
    hote_dct2 = Dct(nom_fichier+".png",idx_dct)

    initiale_dct= deepcopy(hote_dct.get_ext_dct())
    

    # recup bits insignifiants 
    initiale_lscs = matrice_lscs(matrice_to_bits(initiale_dct),LSCs)
    initiale_lscs_a_iterer = deepcopy(initiale_lscs)

    # Nb LSCs 
    lm = len(initiale_lscs)
    # Itérations D2
    ciis = CIIS(random(),random(),random()/2,lm-1,lm)._genere()    
    for k in range(4*lm):
        strat = ciis.next()
        #print strat 
        initiale_lscs_a_iterer[strat] = int(
            f(strat,initiale_lscs_a_iterer))



    # Calcul du nombre de différences dans la matrice DCT


    str_acc += ";#lscs;" + str(lm)

    
    dif_iter = diff(initiale_lscs,initiale_lscs_a_iterer)
    str_acc += ";#diff it ; " + str(dif_iter)
    
    
    # Tatouage de la matrice")
    marquee_ext_dct = embarque2(initiale_lscs_a_iterer,
                           initiale_dct,LSCs)    
    
    # reconstruction de la matrice complete
    hote_dct2.set_ext_dct(marquee_ext_dct)
    marquee_matrice = hote_dct2._matrice
    marquee_matrice = arrondi_en_entier(marquee_matrice)

    
    #sauvegarde de l'image
    (l,h)=marquee_matrice.shape
    at = im.new("L",(l,h))
    at.putdata(marquee_matrice.flatten())
    at.save(nom_fichier+"_bis.png")
    at.show()


#listing = os.listdir(path_cover)



    # psnr
    str_acc += ";PSNR;" + str(PSNR(initiale_matrice,marquee_matrice))

    
    nbexp=10
    # matrices attaquees par rotation
    attaquee_matrice = [] 
    print "rotations"
    for k in range(nbexp):
        at  = Attaque.rotation(marquee_matrice,
                               3*(k+1),
                               1)
        attaquee_matrice += [at]
        
    # matrices attaquees compression jpeg
    print "compression"
    for k in range(nbexp):
        at  = Attaque.jpeg(marquee_matrice,
                           nom_fichier,
                           100-3*(k+1))
        attaquee_matrice +=  [at]
                                          

    # matrices attaquees compression jp2000
    print "compression jp2000"
    for k in range(nbexp):
        at  = Attaque.jp2(marquee_matrice,
                           nom_fichier,
                           100-10*(k+1))
        attaquee_matrice +=  [at]
        


    print "decoupage"        
    # matrices attaquees decoupage
    for k in range(nbexp):
        t = int(0.1*(k+1)*min(taillex,tailley))
        at  = Attaque.decoupage(marquee_matrice,
                                t,
                                (0,0))
        attaquee_matrice +=  [at]

    print "flou"            
    # matrices attaquees flou
    for k in range(nbexp):
        t=0.6+0.08*(k+1)
        at  = Attaque.flou(marquee_matrice,t)
        attaquee_matrice +=  [at]

    print "contraste"            
    # matrices attaquees contrast
    for k in range(nbexp):
        t=0.6+0.08*(k+1)
        at  = Attaque.contraste(marquee_matrice,t)
        attaquee_matrice +=  [at]


    # matrices attaquees redimensionnement
    print "dimensionnement"
    attaquee_matrice += [Attaque.redimensionnement(marquee_matrice,
                                                   0.75,
                                                   1)]
    attaquee_matrice += [Attaque.redimensionnement(marquee_matrice,
                                                   1.5,
                                                   1)]
    



    print  "nombre d'attaques",len(attaquee_matrice)
    
    # evaluation des attaques
    _,(marquee_H2,marquee_V2,marquee),(_,_,_)  = \
        pywt.wavedec2(marquee_matrice,
                      'db1',
                      level = 2)
    marquee_lscs = matrice_lscs(matrice_to_bits(marquee_matrice),LSCs)

    c=0
    for am in attaquee_matrice:
        c +=1
        attaquee = am
        
        attaquee_lscs = matrice_lscs(matrice_to_bits(attaquee),LSCs)
        d = diff(marquee_lscs,attaquee_lscs)
        str_acc += ";" + str(d)



    return str_acc

    
"""
for j in sample(range(1,3001),2):    
    print experience("../images/"+str(j),3)

"""
print experience("../images/"+str(3622),3)
