AND Private Git Repository - these_gilles.git/blob - THESE/codes/bm3D/BM3D/IDDBM3D/BM3DDEB_init.m
Logo AND Algorithmique Numérique Distribuée

Private GIT Repository
11 oct
[these_gilles.git] / THESE / codes / bm3D / BM3D / IDDBM3D / BM3DDEB_init.m
function [ISNR, y_hat_RI, y_hat_RWI, zRI] = BM3DDEB_init(experiment_number, y, z, v, sigma)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% Copyright © 2008 Tampere University of Technology. All rights reserved.
% This work should only be used for nonprofit purposes.
%
% AUTHORS:
%     Kostadin Dabov, email: kostadin.dabov _at_ tut.fi
%     Alessandro Foi
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
%  This function implements the image deblurring method proposed in:
%
%  [1] K. Dabov, A. Foi, V. Katkovnik, and K. Egiazarian, "Image
%   restoration by sparse 3D transform-domain collaborative filtering,"
%   Proc SPIE Electronic Imaging, January 2008.
%
%  Unlike the stand-alone BM3DDEB demo, this variant does not read an image
%  nor synthesize the degradation itself: the caller passes the reference
%  image, the degraded observation, the PSF and the noise level.
%
%  FUNCTION INTERFACE:
%
%  [ISNR, y_hat_RI, y_hat_RWI, zRI] = BM3DDEB_init(experiment_number, y, z, v, sigma)
%
%  INPUT:
%   1) experiment_number: optional (default 3). Only used as a label in the
%                         report printed when print_to_screen == 1.
%                         The experiments of [1] are:
%                         1 -> PSF 1, sigma^2 = 2
%                         2 -> PSF 1, sigma^2 = 8
%                         3 -> PSF 2, sigma^2 = 0.308
%                         4 -> PSF 3, sigma^2 = 49
%                         5 -> PSF 4, sigma^2 = 4
%                         6 -> PSF 5, sigma^2 = 64
%   2) y:      reference (noise-free) image; used only to compute PSNR/ISNR
%   3) z:      blurred and noisy observation, same size as y
%   4) v:      PSF; it is zero-padded to size(y) and circularly centered
%   5) sigma:  standard deviation of the additive white noise in z
%
%   PSNRs are computed with peak value 1, so y, z and sigma are presumably
%   on a [0,1] intensity scale (e.g. sigma = 7/255) -- confirm with caller.
%
%  OUTPUT:
%   1) ISNR:       improvement in SNR (dB) of y_hat_RWI over z, w.r.t. y
%   2) y_hat_RI:   Step-1 estimate (RI + BM3D collaborative hard-thresholding)
%   3) y_hat_RWI:  final Step-2 estimate (RWI + BM3D collaborative Wiener)
%   4) zRI:        regularized-inverse observation that is fed to Step 1
%
%  Requires the mex-functions bm3d_thr_colored_noise and
%  bm3d_wiener_colored_noise.
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%%%% Fixed regularization parameters (obtained empirically after a rough optimization)
Regularization_alpha_RI = 4e-4;
Regularization_alpha_RWI = 5e-3;

%%%% Experiment number (only used as a label in the printed report)
if exist('experiment_number', 'var') ~= 1
    experiment_number = 3; % 1 -- 6
end

%%%% Select 2D transforms ('dct', 'dst', 'hadamard', or anything that is listed by 'help wfilters'):
transform_2D_HT_name      = 'dst'; %% 2D transform (of size N1 x N1) used in Step 1
transform_2D_Wiener_name  = 'dct'; %% 2D transform (of size N1_wiener x N1_wiener) used in Step 2
transform_3rd_dimage_name = 'haar'; %% 1D transform used in the 3rd dim, the same for both steps

%%%% Step 1 (BM3D with collaborative hard-thresholding) parameters:
N1                  = 8;   %% N1 x N1 is the block size
Nstep               = 3;   %% sliding step to process every next reference block
N2                  = 16;  %% maximum number of similar blocks (maximum size of the 3rd dimension of a 3D array)
Ns                  = 39;  %% length of the side of the search neighborhood for full-search block-matching (BM)
tau_match           = 6000;%% threshold for the block distance (d-distance)
lambda_thr2D        = 0;   %% threshold for the coarse initial denoising used in the d-distance measure
lambda_thr3D        = 2.9; %% threshold for the hard-thresholding
beta                = 0;   %% the beta parameter of the 2D Kaiser window used in the reconstruction

%%%% Step 2 (BM3D with collaborative Wiener filtering) parameters:
N1_wiener           = 8;
Nstep_wiener        = 2;
N2_wiener           = 16;
Ns_wiener           = 39;
tau_match_wiener    = 800;
beta_wiener         = 0;

%%%%  Specify whether to print results and display images
print_to_screen     = 0;


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%% Note: touch below this point only if you know what you are doing!
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%%%% Make parameters compatible with the interface of the mex-functions
[Tfor, Tinv]   = getTransfMatrix(N1, transform_2D_HT_name, 0);           %% (normalized) forward/inverse matrices, Step 1
[TforW, TinvW] = getTransfMatrix(N1_wiener, transform_2D_Wiener_name, 0); %% (normalized) forward/inverse matrices, Step 2

if strcmp(transform_3rd_dimage_name, 'haar') == 1
    %%% The mex-functions use a fast internal Haar transform, so no
    %%% transform matrices need to be generated.
    hadper_trans_single_den         = {};
    inverse_hadper_trans_single_den = {};
else
    %%% Create transform matrices for every power-of-two group size up to
    %%% max(N2, N2_wiener); they are later applied by vector-matrix products.
    for hpow = 0:ceil(log2(max(N2, N2_wiener)))
        h = 2^hpow;
        [Tfor3rd, Tinv3rd] = getTransfMatrix(h, transform_3rd_dimage_name, 0);
        hadper_trans_single_den{h}         = single(Tfor3rd);
        inverse_hadper_trans_single_den{h} = single(Tinv3rd');
    end
end

%%%% Aggregation windows: each must match the block size of its own step.
if beta == 0 && beta_wiener == 0
    Wwin2D        = ones(N1, N1);               % hard-thresholding part
    Wwin2D_wiener = ones(N1_wiener, N1_wiener); % Wiener filtering part
else
    Wwin2D        = kaiser(N1, beta) * kaiser(N1, beta)'; % Kaiser window used in the hard-thresholding part
    Wwin2D_wiener = kaiser(N1_wiener, beta_wiener) * kaiser(N1_wiener, beta_wiener)'; % Kaiser window used in the Wiener filtering part
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%% Reference: how the experiments of [1] build v, sigma and z (the
%%%% caller is now responsible for this; kept for documentation):
%%%%
% y = im2double(imread(test_image_name));
% if experiment_number==1
%     sigma=sqrt(2)/255;
%     for x1=-7:7; for x2=-7:7; v(x1+8,x2+8)=1/(x1^2+x2^2+1); end, end; v=v./sum(v(:));
% end
% if experiment_number==2
%     sigma=sqrt(8)/255;
%     s1=0; for a1=-7:7; s1=s1+1; s2=0; for a2=-7:7; s2=s2+1; v(s1,s2)=1/(a1^2+a2^2+1); end, end;  v=v./sum(v(:));
% end
% if experiment_number==3
%     BSNR=40; sigma=-1; % if "sigma=-1", then the value of sigma depends on the BSNR
%     v=ones(9); v=v./sum(v(:));
% end
% if experiment_number==4
%     sigma=7/255;
%     v=[1 4 6 4 1]'*[1 4 6 4 1]; v=v./sum(v(:));  % PSF
% end
% if experiment_number==5
%     sigma=2/255;
%     v=fspecial('gaussian', 25, 1.6);
% end
% if experiment_number==6
%     sigma=8/255;
%     v=fspecial('gaussian', 25, .4);
% end
% y_blur = imfilter(y, v, 'circular'); % performs blurring (by circular convolution)
% randn('seed',0);  %%% fix seed for the random number generator
% if sigma == -1;   %% check whether to use BSNR in order to define value of sigma
%     sigma=sqrt(norm(y_blur(:)-mean(y_blur(:)),2)^2 /(Xh*Xv*10^(BSNR/10))); % compute sigma from the desired BSNR
% end
% z = y_blur + sigma*randn(Xv,Xh);

[Xv, Xh]  = size(y);
[ghy,ghx] = size(v);
big_v  = zeros(Xv,Xh); big_v(1:ghy,1:ghx)=v;
big_v  = circshift(big_v,-round([(ghy-1)/2 (ghx-1)/2])); % pad PSF with zeros to whole image domain, and center it
V      = fft2(big_v); % frequency response of the PSF

tic;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%% Step 1: Regularized Inversion (RI) followed by
%%%% BM3D with collaborative hard-thresholding
%%%%

%%%% Step 1.1. Regularized Inversion (standard Tikhonov regularization)
RI  = conj(V)./( (abs(V).^2) + Regularization_alpha_RI * Xv*Xh*sigma^2); % transfer function of RI
zRI = real(ifft2( fft2(z).* RI ));   % Regularized Inverse Estimate (RI OBSERVATION)

%%%% Per-coefficient std of the colored noise in zRI, in the Step-1 transform
stdRI = getColoredNoiseStd(Tinv, RI, sigma);

%%%% Step 1.2. Colored noise suppression by BM3D with collaborative
%%%% hard-thresholding (thresholds are normalized to the [0,1] scale)
y_hat_RI = bm3d_thr_colored_noise(zRI, hadper_trans_single_den, Nstep, N1, N2, lambda_thr2D,...
    lambda_thr3D, tau_match*N1*N1/(255*255), (Ns-1)/2, sigma, 0, single(Tfor), single(Tinv)',...
    inverse_hadper_trans_single_den, single(stdRI'), Wwin2D, 0, 1 );

PSNR_INITIAL_ESTIMATE = 10*log10(1/mean((y(:)-y_hat_RI(:)).^2));

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%% Step 2: Regularized Wiener Inversion (RWI) followed
%%%% by BM3D with collaborative Wiener filtering
%%%%

%%%% Step 2.1. Regularized Wiener Inversion, using the Step-1 estimate as pilot
Wiener_Pilot = abs(fft2(double(y_hat_RI)));   %%% Wiener reference estimate
RWI  = conj(V).*Wiener_Pilot.^2./(Wiener_Pilot.^2.*(abs(V).^2) + Regularization_alpha_RWI*Xv*Xh*sigma^2);   % transfer function of RWI (Tikhonov-like regularization)
zRWI = real(ifft2(fft2(z).*RWI));   % RWI OBSERVATION

%%%% Per-coefficient std of the colored noise in zRWI, computed in the
%%%% Step-2 transform (N1_wiener x N1_wiener) so it matches what the
%%%% Wiener mex-function expects.
stdRWI = getColoredNoiseStd(TinvW, RWI, sigma);

%%%% Step 2.2. Colored noise suppression by BM3D with collaborative Wiener
%%%% filtering
y_hat_RWI = bm3d_wiener_colored_noise(zRWI, y_hat_RI, hadper_trans_single_den, Nstep_wiener, N1_wiener, N2_wiener, ...
     0, tau_match_wiener*N1_wiener*N1_wiener/(255*255), (Ns_wiener-1)/2, 0, single(stdRWI'), single(TforW), single(TinvW)',...
     inverse_hadper_trans_single_den, Wwin2D_wiener, 0, 1, single(ones(N1_wiener)) );

elapsed_time = toc;


%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%% Calculate the final estimate's PSNR and ISNR, print them, and show the
%%%% restored image
%%%%
PSNR = 10*log10(1/mean((y(:)-y_hat_RWI(:)).^2));
ISNR = PSNR - 10*log10(1/mean((y(:)-z(:)).^2));

if print_to_screen == 1
    % No image name is available here (the image is passed in by the
    % caller), so the report is labelled by experiment number only.
    fprintf('Exp %d, Time: %.1f sec, PSNR-RI: %.2f dB, PSNR-RWI: %.2f, ISNR-RWI: %.2f dB\n', ...
        experiment_number, elapsed_time, PSNR_INITIAL_ESTIMATE, PSNR, ISNR);
    figure, imshow(z);
    figure, imshow(double(y_hat_RWI));
end

return;


function noise_std = getColoredNoiseStd(Tinv, transfer, sigma)
%
% Standard deviation of each 2D-transform coefficient of white noise with
% std sigma, after the noise has been filtered (circularly) by the
% frequency response `transfer`, which has the size of the image.
%
% For each (ii,jj), the basis element Tinv*E_ij*Tinv' (E_ij = unit matrix)
% is zero-padded to the image size. By Parseval, the variance of the
% corresponding coefficient is
%     sigma^2 * sum(|FFT(basis) .* transfer|.^2) / (Xv*Xh).
%
% INPUTS:
%   Tinv      --> (N x N) inverse 2D-transform matrix (from getTransfMatrix)
%   transfer  --> (Xv x Xh) frequency response applied to the observation
%   sigma     --> std of the white noise before filtering
%
% OUTPUT:
%   noise_std --> (N x N) std of each transform coefficient
%
N = size(Tinv, 1);
[Xv, Xh] = size(transfer);
noise_std = zeros(N, N);
for ii = 1:N
    for jj = 1:N
        UnitMatrix = zeros(N, N); UnitMatrix(ii, jj) = 1;
        BasisElementPadded = zeros(Xv, Xh);
        BasisElementPadded(1:N, 1:N) = Tinv * UnitMatrix * Tinv';
        TransfBasisElementPadded = fft2(BasisElementPadded);
        noise_std(ii, jj) = sqrt( (1/(Xv*Xh)) * sum(sum(abs(TransfBasisElementPadded .* transfer).^2)) ) * sigma;
    end
end

return;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Some auxiliary functions
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%



function [Tforward, Tinverse] = getTransfMatrix (N, transform_type, dec_levels)
%
% Create forward and inverse transform matrices, which allow for perfect
% reconstruction. The forward transform matrix is normalized so that the
% l2-norm of each basis element (each row of Tforward) is 1.
%
% [Tforward, Tinverse] = getTransfMatrix (N, transform_type, dec_levels)
%
%  INPUTS:
%
%   N               --> Size of the transform (for wavelets, must be 2^K)
%
%   transform_type  --> 'dct', 'dst', 'hadamard', or anything that is
%                       listed by 'help wfilters' (bi-orthogonal wavelets)
%                       'DCrand' -- an orthonormal transform with a DC and all
%                       the other basis elements of random nature
%                       For N == 8, 'dct', 'dst' and 'bior1.5' are
%                       hard-coded so that no toolbox is needed.
%
%   dec_levels      --> (optional, default 0) For a wavelet transform, the
%                       number of levels by which the decomposition is
%                       reduced w.r.t. the full one: wavedec is run with
%                       log2(N) - dec_levels levels. Must be in the range
%                       [0, log2(N)-1], where "0" implies full
%                       decomposition. Ignored for non-wavelet transforms.
%
%  OUTPUTS:
%
%   Tforward        --> (N x N) Forward transform matrix (rows = basis)
%
%   Tinverse        --> (N x N) Inverse transform matrix
%
%  NOTES:
%   - The wavelet branch calls dwtmode('per','nodisp'), which changes the
%     global extension mode of the Wavelet Toolbox.
%   - 'DCrand' draws from randn, so its result depends on (and advances)
%     the global random number generator state.
%

if nargin < 3 || isempty(dec_levels)
    dec_levels = 0;
end

if N == 1
    Tforward = 1;
elseif strcmp(transform_type, 'hadamard') == 1
    Tforward    = hadamard(N);
elseif (N == 8) && strcmp(transform_type, 'bior1.5') == 1 % hardcoded transform so that the wavelet toolbox is not needed to generate it
    Tforward =  [ 0.353553390593274   0.353553390593274   0.353553390593274   0.353553390593274   0.353553390593274   0.353553390593274   0.353553390593274   0.353553390593274;
       0.219417649252501   0.449283757993216   0.449283757993216   0.219417649252501  -0.219417649252501  -0.449283757993216  -0.449283757993216  -0.219417649252501;
       0.569359398342846   0.402347308162278  -0.402347308162278  -0.569359398342846  -0.083506045090284   0.083506045090284  -0.083506045090284   0.083506045090284;
      -0.083506045090284   0.083506045090284  -0.083506045090284   0.083506045090284   0.569359398342846   0.402347308162278  -0.402347308162278  -0.569359398342846;
       0.707106781186547  -0.707106781186547                   0                   0                   0                   0                   0                   0;
                       0                   0   0.707106781186547  -0.707106781186547                   0                   0                   0                   0;
                       0                   0                   0                   0   0.707106781186547  -0.707106781186547                   0                   0;
                       0                   0                   0                   0                   0                   0   0.707106781186547  -0.707106781186547];
elseif (N == 8) && strcmp(transform_type, 'dct') == 1 % hardcoded transform so that the signal processing toolbox is not needed to generate it
    Tforward = [ 0.353553390593274   0.353553390593274   0.353553390593274   0.353553390593274   0.353553390593274   0.353553390593274   0.353553390593274   0.353553390593274;
       0.490392640201615   0.415734806151273   0.277785116509801   0.097545161008064  -0.097545161008064  -0.277785116509801  -0.415734806151273  -0.490392640201615;
       0.461939766255643   0.191341716182545  -0.191341716182545  -0.461939766255643  -0.461939766255643  -0.191341716182545   0.191341716182545   0.461939766255643;
       0.415734806151273  -0.097545161008064  -0.490392640201615  -0.277785116509801   0.277785116509801   0.490392640201615   0.097545161008064  -0.415734806151273;
       0.353553390593274  -0.353553390593274  -0.353553390593274   0.353553390593274   0.353553390593274  -0.353553390593274  -0.353553390593274   0.353553390593274;
       0.277785116509801  -0.490392640201615   0.097545161008064   0.415734806151273  -0.415734806151273  -0.097545161008064   0.490392640201615  -0.277785116509801;
       0.191341716182545  -0.461939766255643   0.461939766255643  -0.191341716182545  -0.191341716182545   0.461939766255643  -0.461939766255643   0.191341716182545;
       0.097545161008064  -0.277785116509801   0.415734806151273  -0.490392640201615   0.490392640201615  -0.415734806151273   0.277785116509801  -0.097545161008064];
elseif (N == 8) && strcmp(transform_type, 'dst') == 1 % hardcoded transform so that the PDE toolbox is not needed to generate it
    Tforward = [ 0.161229841765317   0.303012985114696   0.408248290463863   0.464242826880013   0.464242826880013   0.408248290463863   0.303012985114696   0.161229841765317;
       0.303012985114696   0.464242826880013   0.408248290463863   0.161229841765317  -0.161229841765317  -0.408248290463863  -0.464242826880013  -0.303012985114696;
       0.408248290463863   0.408248290463863                   0  -0.408248290463863  -0.408248290463863                   0   0.408248290463863   0.408248290463863;
       0.464242826880013   0.161229841765317  -0.408248290463863  -0.303012985114696   0.303012985114696   0.408248290463863  -0.161229841765317  -0.464242826880013;
       0.464242826880013  -0.161229841765317  -0.408248290463863   0.303012985114696   0.303012985114696  -0.408248290463863  -0.161229841765317   0.464242826880013;
       0.408248290463863  -0.408248290463863                   0   0.408248290463863  -0.408248290463863                   0   0.408248290463863  -0.408248290463863;
       0.303012985114696  -0.464242826880013   0.408248290463863  -0.161229841765317  -0.161229841765317   0.408248290463863  -0.464242826880013   0.303012985114696;
       0.161229841765317  -0.303012985114696   0.408248290463863  -0.464242826880013   0.464242826880013  -0.408248290463863   0.303012985114696  -0.161229841765317];
elseif strcmp(transform_type, 'dct') == 1
    Tforward    = dct(eye(N));
elseif strcmp(transform_type, 'dst') == 1
    Tforward    = dst(eye(N));
elseif strcmp(transform_type, 'DCrand') == 1
    %%% QR of a matrix whose first column is constant yields an orthonormal
    %%% basis whose first element is the (normalized) DC vector.
    x = randn(N); x(1:end,1) = 1; [Q,R] = qr(x);
    if (Q(1) < 0)
        Q = -Q;
    end
    Tforward = Q';
else %% a wavelet decomposition supported by 'wavedec'
    n_levels = log2(N) - dec_levels;
    if dec_levels < 0 || n_levels < 1
        error('getTransfMatrix: dec_levels must be in the range [0, log2(N)-1].');
    end

    %%% Set periodic boundary conditions, to preserve bi-orthogonality
    %%% (with 'per', wavedec of a length-N signal returns exactly N coefficients)
    dwtmode('per','nodisp');

    Tforward = zeros(N,N);
    for i = 1:N
        impulse = zeros(1,N); impulse(i) = 1;   % i-th canonical basis vector
        Tforward(:,i) = wavedec(impulse, n_levels, transform_type);  %% construct transform matrix column by column
    end
end

%%% Normalize the basis elements (rows) to unit l2-norm
Tforward = (Tforward' * diag(sqrt(1./sum(Tforward.^2,2))))';

%%% Compute the inverse transform matrix
Tinverse = inv(Tforward);

return;