Total Variation Deconvolution using Split Bregman
tvreg.c
Go to the documentation of this file.
1 
14 #include <math.h>
15 #include <stdlib.h>
16 #include <string.h>
17 #include <stdio.h>
18 #include "tvregopt.h"
19 
20 #ifdef MATLAB_MEX_FILE
21 #include "tvregmex.h"
22 #endif
23 
24 #include "dsolve_inc.c"
25 #if defined(TVREG_DENOISE) || defined(TVREG_INPAINT)
26 #include "usolve_gs_inc.c"
27 #endif
28 #ifdef TVREG_DECONV
29 #include "usolve_dct_inc.c"
30 #include "usolve_dft_inc.c"
31 #endif
32 #ifdef TVREG_USEZ
33 #include "zsolve_inc.c"
34 #endif
35 
98 int TvRestore(num *u, const num *f, int Width, int Height, int NumChannels,
99  tvregopt *Opt)
100 {
101  const long NumPixels = ((long)Width) * ((long)Height);
102  const long NumEl = NumPixels * NumChannels;
103  tvregsolver S;
104  usolver USolveFun = NULL;
105  zsolver ZSolveFun = NULL;
106  num DiffNorm;
107  int i, Success = 0, DeconvFlag, DctFlag, Iter;
108 
109  if(!u || !f || u == f || Width < 2 || Height < 2 || NumChannels <= 0)
110  return 0;
111 
112  /*** Set algorithm flags ***********************************************/
113  S.Opt = (Opt) ? *Opt : TvRegDefaultOpt;
114 
115  if(!TvRestoreChooseAlgorithm(&S.UseZ, &DeconvFlag, &DctFlag,
116  &USolveFun, &ZSolveFun, &S.Opt))
117  return 0;
118 
119 #if !defined(TVREG_DENOISE) && !defined(TVREG_INPAINT)
120  if(!DeconvFlag)
121  {
122  if(!S.Opt.VaryingLambda)
123  fprintf(stderr, "Please recompile with TVREG_DENOISE "
124  "for denoising problems.\n");
125  else
126  fprintf(stderr, "Please recompile with TVREG_INPAINT "
127  "for inpainting problems.\n");
128  return 0;
129  }
130 #endif
131 
132  if(S.Opt.VaryingLambda && (S.Opt.LambdaWidth != Width
133  || S.Opt.LambdaHeight != Height))
134  {
135  fprintf(stderr, "Image is %dx%d but lambda is %dx%d.\n",
136  Width, Height, S.Opt.LambdaWidth, S.Opt.LambdaHeight);
137  return 0;
138  }
139 
140  S.u = u;
141  S.f = f;
142  S.Width = S.PadWidth = Width;
143  S.Height = S.PadHeight = Height;
144  S.NumChannels = NumChannels;
145  S.Alpha = ((!S.UseZ) ? S.Opt.Lambda : S.Opt.Gamma2)
146  / S.Opt.Gamma1;
147 
148  /*** Allocate memory ***************************************************/
149  S.d = S.dtilde = NULL;
150 #ifdef TVREG_USEZ
151  S.z = S.ztilde = NULL;
152 #endif
153 #ifdef TVREG_DECONV
154  S.A = S.B = S.ATrans = S.BTrans = S.KernelTrans = S.DenomTrans = NULL;
155  S.TransformA = S.TransformB = S.InvTransformA = S.InvTransformB = NULL;
156 #endif
157 
158  if(!(S.d = (numvec2 *)Malloc(sizeof(numvec2)*NumEl))
159  || !(S.dtilde = (numvec2 *)Malloc(sizeof(numvec2)*NumEl)))
160  goto Catch;
161 
162  if(S.UseZ)
163 #ifndef TVREG_USEZ
164  { /* We need z but do not have it, show error message. */
165  if(S.Opt.NoiseModel != NOISEMODEL_L2)
166  fprintf(stderr, "Please recompile with TVREG_NONGAUSSIAN "
167  "for non-Gaussian noise models.\n");
168  else
169  fprintf(stderr, "Please recompile with TVREG_DECONV and "
170  "TVREG_INPAINT for deconvolution-inpainting problems.\n");
171 
172  goto Catch;
173  }
174 #else
175  { /* Allocate memory for z and ztilde */
176  if(!(S.z = (num *)Malloc(sizeof(num)*NumEl))
177  || !(S.ztilde = (num *)Malloc(sizeof(num)*NumEl)))
178  goto Catch;
179 
180  /* Initialize z = ztilde = u */
181  memcpy(S.z, S.u, sizeof(num)*NumEl);
182  memcpy(S.ztilde, S.u, sizeof(num)*NumEl);
183  }
184 #endif
185 
186  if(!DeconvFlag)
187  S.Ku = u;
188  else
189 #ifndef TVREG_DECONV
190  { /* We need deconvolution but do not have it, show error message. */
191  fprintf(stderr, "Please recompile with TVREG_DECONV "
192  "for deconvolution problems.\n");
193  goto Catch;
194  }
195 #else /* The following applies only for problems with deconvolution */
196  if(DctFlag)
197  { /* Prepare for DCT-based deconvolution */
198  long PadNumPixels = ((long)Width + 1) * ((long)Height + 1);
199 
200  if(!(S.ATrans = (num *)FFT(malloc)(sizeof(num)*NumEl))
201  || !(S.BTrans = (num *)FFT(malloc)(sizeof(num)*NumEl))
202  || !(S.A = (num *)FFT(malloc)(sizeof(num)*NumEl))
203  || !(S.B = (num *)FFT(malloc)(
204  sizeof(num)*PadNumPixels*NumChannels))
205  || !(S.KernelTrans = (num *)FFT(malloc)(sizeof(num)*PadNumPixels))
206  || !(S.DenomTrans = (num *)Malloc(sizeof(num)*NumPixels))
207  || !InitDeconvDct(&S))
208  goto Catch;
209  }
210  else
211  { /* Prepare for Fourier-based deconvolution */
212  long NumTransPixels, NumTransEl, PadNumEl;
213  int TransWidth;
214 
215  S.PadWidth = 2*Width;
216  S.PadHeight = 2*Height;
217  TransWidth = S.PadWidth/2 + 1;
218  NumTransPixels = ((long)TransWidth) * ((long)S.PadHeight);
219  NumTransEl = NumTransPixels * NumChannels;
220  PadNumEl = (((long)S.PadWidth) * S.PadHeight) * NumChannels;
221 
222  if(!(S.ATrans = (num *)FFT(malloc)(sizeof(numcomplex)*NumTransEl))
223  || !(S.BTrans = (num *)FFT(malloc)(sizeof(numcomplex)*NumTransEl))
224  || !(S.A = (num *)FFT(malloc)(sizeof(num)*PadNumEl))
225  || !(S.B = (num *)FFT(malloc)(sizeof(num)*PadNumEl))
226  || !(S.KernelTrans = (num *)
227  FFT(malloc)(sizeof(numcomplex)*NumTransPixels))
228  || !(S.DenomTrans = (num *)Malloc(sizeof(num)*NumTransPixels))
229  || !InitDeconvFourier(&S))
230  goto Catch;
231  }
232 #endif
233 
234  /*** Algorithm initializations *****************************************/
235 
236  /* Set convergence threshold scaled by norm of f */
237  for(i = 0, S.fNorm = 0; i < NumEl; i++)
238  S.fNorm += f[i] * f[i];
239 
240  S.fNorm = (num)sqrt(S.fNorm);
241 
242  if(S.fNorm == 0) /* Special case, input image is zero */
243  {
244  memcpy(u, f, sizeof(num)*NumEl);
245  Success = 1;
246  goto Catch;
247  }
248 
249  /* Initialize d = dtilde = 0 */
250  for(i = 0; i < NumEl; i++)
251  S.d[i].x = S.d[i].y = 0;
252 
253  for(i = 0; i < NumEl; i++)
254  S.dtilde[i].x = S.dtilde[i].y = 0;
255 
256  DiffNorm = (S.Opt.Tol > 0) ? 1000*S.Opt.Tol : 1000;
257  Success = 2;
258 
259  if(S.Opt.PlotFun && !S.Opt.PlotFun(0, 0, DiffNorm,
260  u, Width, Height, NumChannels, S.Opt.PlotParam))
261  goto Catch;
262 
263  /*** Algorithm main loop: Bregman iterations ***************************/
264  for(Iter = 1; Iter <= S.Opt.MaxIter; Iter++)
265  {
266  /* Solve d subproblem and update dtilde */
267  DSolve(&S);
268 
269  /* Solve u subproblem */
270  DiffNorm = USolveFun(&S);
271 
272  if(Iter >= 2 + S.UseZ && DiffNorm < S.Opt.Tol)
273  break;
274 
275 #ifdef TVREG_USEZ
276  /* Solve z subproblem and update ztilde */
277  if(S.UseZ)
278  ZSolveFun(&S);
279 #endif
280 
281  if(S.Opt.PlotFun && !(S.Opt.PlotFun(0, Iter, DiffNorm, u,
282  Width, Height, NumChannels, S.Opt.PlotParam)))
283  goto Catch;
284  }
285  /*** End of main loop **************************************************/
286 
287  Success = (Iter <= S.Opt.MaxIter) ? 1 : 2;
288 
289  if(S.Opt.PlotFun)
290  S.Opt.PlotFun(Success, (Iter <= S.Opt.MaxIter) ? Iter : S.Opt.MaxIter,
291  DiffNorm, u, Width, Height, NumChannels, S.Opt.PlotParam);
292 Catch:
293  /*** Release memory ****************************************************/
294  if(S.dtilde)
295  Free(S.dtilde);
296  if(S.d)
297  Free(S.d);
298 #ifdef TVREG_USEZ
299  if(S.ztilde)
300  Free(S.ztilde);
301  if(S.z)
302  Free(S.z);
303 #endif
304 #ifdef TVREG_DECONV
305  if(DeconvFlag)
306  {
307  if(S.DenomTrans)
308  Free(S.DenomTrans);
309  if(S.KernelTrans)
310  FFT(free)(S.KernelTrans);
311  if(S.B)
312  FFT(free)(S.B);
313  if(S.A)
314  FFT(free)(S.A);
315  if(S.BTrans)
316  FFT(free)(S.BTrans);
317  if(S.ATrans)
318  FFT(free)(S.ATrans);
319 
320  FFT(destroy_plan)(S.InvTransformB);
321  FFT(destroy_plan)(S.TransformB);
322  FFT(destroy_plan)(S.InvTransformA);
323  FFT(destroy_plan)(S.TransformA);
324  FFT(cleanup)();
325  }
326 #endif
327  return Success;
328 }
329 
330 
332 static int IsSymmetric(const num *Kernel, int KernelWidth, int KernelHeight)
333 {
334  int x, xr, y, yr;
335 
336  if(KernelWidth % 2 == 0 || KernelHeight % 2 == 0)
337  return 0;
338 
339  for(y = 0, yr = KernelHeight - 1; y < KernelHeight; y++, yr--)
340  for(x = 0, xr = KernelWidth - 1; x < KernelWidth; x++, xr--)
341  if(Kernel[x + KernelWidth*y] != Kernel[xr + KernelWidth*y]
342  || Kernel[x + KernelWidth*y] != Kernel[x + KernelWidth*yr])
343  return 0;
344 
345  return 1;
346 }
347 
348 
350 static int TvRestoreChooseAlgorithm(int *UseZ, int *DeconvFlag, int *DctFlag,
351  usolver *USolveFun, zsolver *ZSolveFun, const tvregopt *Opt)
352 {
353  if(!Opt)
354  return 0;
355 
356  /* UseZ decides between the simpler d,u splitting or the d,u,z splitting
357  of the problem. ZSolveFun selects the z-subproblem solver. */
358  *UseZ = (Opt->NoiseModel != NOISEMODEL_L2);
359 
360 #ifndef TVREG_USEZ
361  *ZSolveFun = NULL;
362 #else
363  switch(Opt->NoiseModel)
364  {
365  case NOISEMODEL_L2:
366  *ZSolveFun = ZSolveL2;
367  break;
368  case NOISEMODEL_L1:
369  *ZSolveFun = ZSolveL1;
370  break;
371  case NOISEMODEL_POISSON:
372  *ZSolveFun = ZSolvePoisson;
373  break;
374  default:
375  return 0;
376  }
377 #endif
378 
379  /* If there is a kernel, set DeconvFlag */
380  if(Opt->Kernel)
381  {
382  /* Must use d,u,z splitting for deconvolution with
383  spatially-varying lambda */
384  if(Opt->VaryingLambda)
385  *UseZ = 1;
386 
387  *DeconvFlag = 1;
388  /* Use faster DCT solver if kernel is symmetric in both dimensions */
389  *DctFlag = IsSymmetric(Opt->Kernel,
390  Opt->KernelWidth, Opt->KernelHeight);
391  }
392  else
393  *DeconvFlag = *DctFlag = 0;
394 
395  /* Select the u-subproblem solver */
396  if(!*DeconvFlag) /* Gauss-Seidel solver for denoising and inpainting */
397 #if defined(TVREG_DENOISE) || defined(TVREG_INPAINT)
398  *USolveFun = (!Opt->VaryingLambda) ?
399  UGaussSeidelConstantLambda : UGaussSeidelVaryingLambda;
400 #else
401  *USolveFun = NULL;
402 #endif
403 #ifdef TVREG_DECONV
404 #ifdef TVREG_USEZ
405  else if(*UseZ)
406  *USolveFun = (*DctFlag) ? UDeconvDctZ : UDeconvFourierZ;
407 #endif
408  else
409  *USolveFun = (*DctFlag) ? UDeconvDct : UDeconvFourier;
410 #endif
411  return 1;
412 }