Total Variation Inpainting using Split Bregman
tvreg.c
Go to the documentation of this file.
1 
14 #include <math.h>
15 #include <stdlib.h>
16 #include <string.h>
17 #include <stdio.h>
18 #include "tvregopt.h"
19 
20 #ifdef MATLAB_MEX_FILE
21 #include "tvregmex.h"
22 #endif
23 
24 #include "dsolve_inc.c"
25 #if defined(TVREG_DENOISE) || defined(TVREG_INPAINT)
26 #include "usolve_gs_inc.c"
27 #endif
28 #ifdef TVREG_DECONV
29 #include "usolve_dct_inc.c"
30 #include "usolve_dft_inc.c"
31 #endif
32 #ifdef TVREG_USEZ
33 #include "zsolve_inc.c"
34 #endif
35 
98 int TvRestore(num *u, const num *f, int Width, int Height, int NumChannels,
99  tvregopt *Opt)
100 {
101  const long NumPixels = ((long)Width) * ((long)Height);
102  const long NumEl = NumPixels * NumChannels;
103  tvregsolver S;
104  usolver USolveFun = NULL;
105  zsolver ZSolveFun = NULL;
106  num DiffNorm;
107  int i, Success = 0, DeconvFlag, DctFlag, Iter;
108 
109  if(!u || !f || u == f || Width < 2 || Height < 2 || NumChannels <= 0)
110  return 0;
111 
112  /*** Set algorithm flags ***********************************************/
113  S.Opt = (Opt) ? *Opt : TvRegDefaultOpt;
114 
115  if(!TvRestoreChooseAlgorithm(&S.UseZ, &DeconvFlag, &DctFlag,
116  &USolveFun, &ZSolveFun, &S.Opt))
117  return 0;
118 
119 #if !defined(TVREG_DENOISE) && !defined(TVREG_INPAINT)
120  if(!DeconvFlag)
121  {
122  if(!S.Opt.VaryingLambda)
123  fprintf(stderr, "Please recompile with TVREG_DENOISE "
124  "for denoising problems.\n");
125  else
126  fprintf(stderr, "Please recompile with TVREG_INPAINT "
127  "for inpainting problems.\n");
128  return 0;
129  }
130 #endif
131 
132  if(S.Opt.VaryingLambda && (S.Opt.LambdaWidth != Width
133  || S.Opt.LambdaHeight != Height))
134  {
135  fprintf(stderr, "Image is %dx%d but lambda is %dx%d.\n",
136  Width, Height, S.Opt.LambdaWidth, S.Opt.LambdaHeight);
137  return 0;
138  }
139 
140  S.u = u;
141  S.f = f;
142  S.Width = S.PadWidth = Width;
143  S.Height = S.PadHeight = Height;
144  S.NumChannels = NumChannels;
145  S.Alpha = ((!S.UseZ) ? S.Opt.Lambda : S.Opt.Gamma2)
146  / S.Opt.Gamma1;
147 
148  /*** Allocate memory ***************************************************/
149  S.d = S.dtilde = NULL;
150 #ifdef TVREG_USEZ
151  S.z = S.ztilde = NULL;
152 #endif
153 #ifdef TVREG_DECONV
154  S.A = S.B = S.ATrans = S.BTrans = S.KernelTrans = S.DenomTrans = NULL;
155  S.TransformA = S.TransformB = S.InvTransformA = S.InvTransformB = NULL;
156 #endif
157 
158  if(!(S.d = (numvec2 *)Malloc(sizeof(numvec2)*NumEl))
159  || !(S.dtilde = (numvec2 *)Malloc(sizeof(numvec2)*NumEl)))
160  goto Catch;
161 
162  if(S.UseZ)
163 #ifndef TVREG_USEZ
164  { /* We need z but do not have it, show error message. */
165  if(S.Opt.NoiseModel != NOISEMODEL_L2)
166  fprintf(stderr, "Please recompile with TVREG_NONGAUSSIAN "
167  "for non-Gaussian noise models.\n");
168  else
169  fprintf(stderr, "Please recompile with TVREG_DECONV and "
170  "TVREG_INPAINT for deconvolution-inpainting problems.\n");
171 
172  goto Catch;
173  }
174 #else
175  { /* Allocate memory for z and ztilde */
176  if(!(S.z = (num *)Malloc(sizeof(num)*NumEl))
177  || !(S.ztilde = (num *)Malloc(sizeof(num)*NumEl)))
178  goto Catch;
179 
180  /* Initialize z = ztilde = u */
181  memcpy(S.z, S.u, sizeof(num)*NumEl);
182  memcpy(S.ztilde, S.u, sizeof(num)*NumEl);
183  }
184 #endif
185 
186  if(!DeconvFlag)
187  S.Ku = u;
188  else
189 #ifndef TVREG_DECONV
190  { /* We need deconvolution but do not have it, show error message. */
191  fprintf(stderr, "Please recompile with TVREG_DECONV "
192  "for deconvolution problems.\n");
193  goto Catch;
194  }
195 #else /* The following applies only for problems with deconvolution */
196  if(DctFlag)
197  { /* Prepare for DCT-based deconvolution */
198  if(!(S.ATrans = (num *)FFT(malloc)(sizeof(num)*NumEl))
199  || !(S.BTrans = (num *)FFT(malloc)(sizeof(num)*NumEl))
200  || !(S.A = (num *)FFT(malloc)(sizeof(num)*NumEl))
201  || !(S.B = (num *)FFT(malloc)(sizeof(num)*NumEl))
202  || !(S.KernelTrans = (num *)FFT(malloc)(sizeof(num)*NumPixels))
203  || !(S.DenomTrans = (num *)Malloc(sizeof(num)*NumPixels))
204  || !InitDeconvDct(&S))
205  goto Catch;
206  }
207  else
208  { /* Prepare for Fourier-based deconvolution */
209  long NumTransPixels, NumTransEl, PadNumEl;
210  int TransWidth;
211 
212  S.PadWidth = 2*Width;
213  S.PadHeight = 2*Height;
214  TransWidth = S.PadWidth/2 + 1;
215  NumTransPixels = ((long)TransWidth) * ((long)S.PadHeight);
216  NumTransEl = NumTransPixels * NumChannels;
217  PadNumEl = (((long)S.PadWidth) * S.PadHeight) * NumChannels;
218 
219  if(!(S.ATrans = (num *)FFT(malloc)(sizeof(numcomplex)*NumTransEl))
220  || !(S.BTrans = (num *)FFT(malloc)(sizeof(numcomplex)*NumTransEl))
221  || !(S.A = (num *)FFT(malloc)(sizeof(num)*PadNumEl))
222  || !(S.B = (num *)FFT(malloc)(sizeof(num)*PadNumEl))
223  || !(S.KernelTrans = (num *)
224  FFT(malloc)(sizeof(numcomplex)*NumTransPixels))
225  || !(S.DenomTrans = (num *)Malloc(sizeof(num)*NumTransPixels))
226  || !InitDeconvFourier(&S))
227  goto Catch;
228  }
229 #endif
230 
231  /*** Algorithm initializations *****************************************/
232 
233  /* Set convergence threshold scaled by norm of f */
234  for(i = 0, S.fNorm = 0; i < NumEl; i++)
235  S.fNorm += f[i] * f[i];
236 
237  S.fNorm = (num)sqrt(S.fNorm);
238 
239  if(S.fNorm == 0) /* Special case, input image is zero */
240  {
241  memcpy(u, f, sizeof(num)*NumEl);
242  Success = 1;
243  goto Catch;
244  }
245 
246  /* Initialize d = dtilde = 0 */
247  for(i = 0; i < NumEl; i++)
248  S.d[i].x = S.d[i].y = 0;
249 
250  for(i = 0; i < NumEl; i++)
251  S.dtilde[i].x = S.dtilde[i].y = 0;
252 
253  DiffNorm = (S.Opt.Tol > 0) ? 1000*S.Opt.Tol : 1000;
254  Success = 2;
255 
256  if(S.Opt.PlotFun && !S.Opt.PlotFun(0, 0, DiffNorm,
257  u, Width, Height, NumChannels, S.Opt.PlotParam))
258  goto Catch;
259 
260  /*** Algorithm main loop: Bregman iterations ***************************/
261  for(Iter = 1; Iter <= S.Opt.MaxIter; Iter++)
262  {
263  /* Solve d subproblem and update dtilde */
264  DSolve(&S);
265 
266  /* Solve u subproblem */
267  DiffNorm = USolveFun(&S);
268 
269  if(Iter >= 2 + S.UseZ && DiffNorm < S.Opt.Tol)
270  break;
271 
272 #ifdef TVREG_USEZ
273  /* Solve z subproblem and update ztilde */
274  if(S.UseZ)
275  ZSolveFun(&S);
276 #endif
277 
278  if(S.Opt.PlotFun && !(S.Opt.PlotFun(0, Iter, DiffNorm, u,
279  Width, Height, NumChannels, S.Opt.PlotParam)))
280  goto Catch;
281  }
282  /*** End of main loop **************************************************/
283 
284  Success = (Iter <= S.Opt.MaxIter) ? 1 : 2;
285 
286  if(S.Opt.PlotFun)
287  S.Opt.PlotFun(Success, (Iter <= S.Opt.MaxIter) ? Iter : S.Opt.MaxIter,
288  DiffNorm, u, Width, Height, NumChannels, S.Opt.PlotParam);
289 Catch:
290  /*** Release memory ****************************************************/
291  if(S.dtilde)
292  Free(S.dtilde);
293  if(S.d)
294  Free(S.d);
295 #ifdef TVREG_USEZ
296  if(S.ztilde)
297  Free(S.ztilde);
298  if(S.z)
299  Free(S.z);
300 #endif
301 #ifdef TVREG_DECONV
302  if(DeconvFlag)
303  {
304  if(S.DenomTrans)
305  Free(S.DenomTrans);
306  if(S.KernelTrans)
307  FFT(free)(S.KernelTrans);
308  if(S.B)
309  FFT(free)(S.B);
310  if(S.A)
311  FFT(free)(S.A);
312  if(S.BTrans)
313  FFT(free)(S.BTrans);
314  if(S.ATrans)
315  FFT(free)(S.ATrans);
316 
317  FFT(destroy_plan)(S.InvTransformB);
318  FFT(destroy_plan)(S.TransformB);
319  FFT(destroy_plan)(S.InvTransformA);
320  FFT(destroy_plan)(S.TransformA);
321  FFT(cleanup)();
322  }
323 #endif
324  return Success;
325 }
326 
327 
329 static int IsSymmetric(const num *Kernel, int KernelWidth, int KernelHeight)
330 {
331  int x, xr, y, yr;
332 
333  if(KernelWidth % 2 == 0 || KernelHeight % 2 == 0)
334  return 0;
335 
336  for(y = 0, yr = KernelHeight - 1; y < KernelHeight; y++, yr--)
337  for(x = 0, xr = KernelWidth - 1; x < KernelWidth; x++, xr--)
338  if(Kernel[x + KernelWidth*y] != Kernel[xr + KernelWidth*y]
339  || Kernel[x + KernelWidth*y] != Kernel[x + KernelWidth*yr])
340  return 0;
341 
342  return 1;
343 }
344 
345 
/**
 * @brief Choose the variable splitting and subproblem solvers for TvRestore
 *
 * @param UseZ        set to 1 for the d,u,z splitting, 0 for the d,u
 *                    splitting. It is 1 for any noise model other than
 *                    NOISEMODEL_L2, and also for deconvolution with
 *                    Opt->VaryingLambda.
 * @param DeconvFlag  set to 1 if Opt->Kernel is non-NULL (deconvolution),
 *                    otherwise 0
 * @param DctFlag     when deconvolving, set to the result of IsSymmetric on
 *                    the kernel (selects the DCT-based solver); otherwise 0
 * @param USolveFun   set to the u-subproblem solver:
 *                    - Gauss-Seidel variants (chosen by Opt->VaryingLambda)
 *                      without deconvolution, or NULL if neither
 *                      TVREG_DENOISE nor TVREG_INPAINT is compiled in;
 *                    - UDeconvDct/UDeconvFourier, or their Z variants when
 *                      UseZ and TVREG_USEZ, for deconvolution.
 *                    NOTE(review): if Opt->Kernel is set but TVREG_DECONV is
 *                    not compiled in, *USolveFun is not written at all, so
 *                    callers must pre-initialize it (TvRestore sets it to
 *                    NULL).
 * @param ZSolveFun   set to ZSolveL2/ZSolveL1/ZSolvePoisson according to
 *                    Opt->NoiseModel, or NULL if TVREG_USEZ is not compiled in
 * @param Opt         algorithm options, read only
 * @return 0 if Opt is NULL or, with TVREG_USEZ, if Opt->NoiseModel is not
 *         recognized; otherwise 1. On a 0 return, the outputs may be only
 *         partially written.
 */
347 static int TvRestoreChooseAlgorithm(int *UseZ, int *DeconvFlag, int *DctFlag,
348  usolver *USolveFun, zsolver *ZSolveFun, const tvregopt *Opt)
349 {
350  if(!Opt)
351  return 0;
352 
353  /* UseZ decides between the simpler d,u splitting or the d,u,z splitting
354  of the problem. ZSolveFun selects the z-subproblem solver. */
355  *UseZ = (Opt->NoiseModel != NOISEMODEL_L2);
356 
357 #ifndef TVREG_USEZ
358  *ZSolveFun = NULL;
359 #else
360  switch(Opt->NoiseModel)
361  {
362  case NOISEMODEL_L2:
363  *ZSolveFun = ZSolveL2;
364  break;
365  case NOISEMODEL_L1:
366  *ZSolveFun = ZSolveL1;
367  break;
368  case NOISEMODEL_POISSON:
369  *ZSolveFun = ZSolvePoisson;
370  break;
/* Unrecognized noise model: fail before DeconvFlag/DctFlag/USolveFun are set */
371  default:
372  return 0;
373  }
374 #endif
375 
376  /* If there is a kernel, set DeconvFlag */
377  if(Opt->Kernel)
378  {
379  /* Must use d,u,z splitting for deconvolution with
380  spatially-varying lambda */
381  if(Opt->VaryingLambda)
382  *UseZ = 1;
383 
384  *DeconvFlag = 1;
385  /* Use faster DCT solver if kernel is symmetric in both dimensions */
386  *DctFlag = IsSymmetric(Opt->Kernel,
387  Opt->KernelWidth, Opt->KernelHeight);
388  }
389  else
390  *DeconvFlag = *DctFlag = 0;
391 
392  /* Select the u-subproblem solver */
393  if(!*DeconvFlag) /* Gauss-Seidel solver for denoising and inpainting */
394 #if defined(TVREG_DENOISE) || defined(TVREG_INPAINT)
395  *USolveFun = (!Opt->VaryingLambda) ?
/* NOTE(review): listing line 396 is missing from this copy. That line
   completes this conditional expression with the solvers for constant and
   spatially varying lambda, so the statement is incomplete as shown.
   Restore it from the upstream tvreg.c before compiling. */
397 #else
398  *USolveFun = NULL;
399 #endif
400 #ifdef TVREG_DECONV
401 #ifdef TVREG_USEZ
402  else if(*UseZ)
403  *USolveFun = (*DctFlag) ? UDeconvDctZ : UDeconvFourierZ;
404 #endif
/* Without TVREG_USEZ, the non-Z solvers are selected even when *UseZ is 1.
   TvRestore then reports the missing z support and fails. */
405  else
406  *USolveFun = (*DctFlag) ? UDeconvDct : UDeconvFourier;
407 #endif
408  return 1;
409 }