Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

""" 

Discrete Wavelet Transform 

  

Wraps ``GSL's gsl_wavelet_transform_forward()``, 

and ``gsl_wavelet_transform_inverse()`` and creates plot methods. 

  

AUTHOR: 

  

- Josh Kantor (2006-10-07) - initial version 

- David Joyner (2006-10-09) - minor changes to docstrings and examples. 

  

""" 

  

#***************************************************************************** 

# Copyright (C) 2006 Joshua Kantor <jkantor@math.washington.edu> 

# 

# This program is free software: you can redistribute it and/or modify 

# it under the terms of the GNU General Public License as published by 

# the Free Software Foundation, either version 2 of the License, or 

# (at your option) any later version. 

# http://www.gnu.org/licenses/ 

#***************************************************************************** 

from __future__ import absolute_import 

  

import sage.plot.all 

  

def WaveletTransform(n, wavelet_type, wavelet_k): 

""" 

This function initializes an GSLDoubleArray of length n which 

can perform a discrete wavelet transform. 

  

INPUT: 

  

- ``n`` -- a power of 2 

- ``T`` -- the data in the GSLDoubleArray must be real 

- ``wavelet_type`` -- the name of the type of wavelet, valid choices are: 

  

* ``'daubechies'`` 

* ``'daubechies_centered'`` 

* ``'haar'`` 

* ``'haar_centered'`` 

* ``'bspline'`` 

* ``'bspline_centered'`` 

  

For daubechies wavelets, ``wavelet_k`` specifies a daubechie wavelet 

with `k/2` vanishing moments. `k = 4,6,...,20` for `k` even are the 

only ones implemented. 

  

For Haar wavelets, ``wavelet_k`` must be 2. 

  

For bspline wavelets, ``wavelet_k`` of `103,105,202,204,206,208,301, 

305,307,309` will give biorthogonal B-spline wavelets of order `(i,j)` 

where ``wavelet_k`` is `100*i+j`. 

The wavelet transform uses `J = \log_2(n)` levels. 

  

OUTPUT: 

  

An array of the form 

`(s_{-1,0}, d_{0,0}, d_{1,0}, d_{1,1}, d_{2,0}, \ldots, d_{J-1,2^{J-1}-1})` 

for `d_{j,k}` the detail coefficients of level `j`. 

The centered forms align the coefficients of the sub-bands on edges. 

  

EXAMPLES:: 

  

sage: a = WaveletTransform(128,'daubechies',4) 

sage: for i in range(1, 11): 

....: a[i] = 1 

....: a[128-i] = 1 

sage: a.plot().show(ymin=0) 

sage: a.forward_transform() 

sage: a.plot().show() 

sage: a = WaveletTransform(128,'haar',2) 

sage: for i in range(1, 11): a[i] = 1; a[128-i] = 1 

sage: a.forward_transform() 

sage: a.plot().show(ymin=0) 

sage: a = WaveletTransform(128,'bspline_centered',103) 

sage: for i in range(1, 11): a[i] = 1; a[100+i] = 1 

sage: a.forward_transform() 

sage: a.plot().show(ymin=0) 

  

This example gives a simple example of wavelet compression:: 

  

sage: a = DWT(2048,'daubechies',6) 

sage: for i in range(2048): a[i]=float(sin((i*5/2048)**2)) 

sage: a.plot().show() # long time (7s on sage.math, 2011) 

sage: a.forward_transform() 

sage: for i in range(1800): a[2048-i-1] = 0 

sage: a.backward_transform() 

sage: a.plot().show() # long time (7s on sage.math, 2011) 

""" 

cdef size_t _n, _k 

_n = int(n) 

if _n < 0: 

raise ValueError("n must be nonnegative.") 

_k = int(wavelet_k) 

if not is2pow(_n): 

raise NotImplementedError("discrete wavelet transform only implemented when n is a 2-power") 

return DiscreteWaveletTransform(_n,1,wavelet_type,_k) 

  

DWT = WaveletTransform 

  

cdef class DiscreteWaveletTransform(GSLDoubleArray): 

""" 

Discrete wavelet transform class. 

""" 

def __cinit__(self,size_t n,size_t stride, wavelet_type, size_t wavelet_k): 

self.wavelet = NULL 

self.workspace = NULL 

  

def __init__(self,size_t n,size_t stride, wavelet_type, size_t wavelet_k): 

if not is2pow(n): 

raise NotImplementedError("discrete wavelet transform only implemented when n is a 2-power") 

GSLDoubleArray.__init__(self,n,stride) 

if wavelet_type=="daubechies": 

self.wavelet = <gsl_wavelet*> gsl_wavelet_alloc(gsl_wavelet_daubechies, wavelet_k) 

elif wavelet_type == "daubechies_centered": 

self.wavelet = <gsl_wavelet*> gsl_wavelet_alloc(gsl_wavelet_daubechies_centered,wavelet_k) 

elif wavelet_type == "haar": 

self.wavelet = <gsl_wavelet *> gsl_wavelet_alloc(gsl_wavelet_haar,wavelet_k) 

elif wavelet_type == "haar_centered": 

self.wavelet = <gsl_wavelet*> gsl_wavelet_alloc(gsl_wavelet_haar_centered,wavelet_k) 

elif wavelet_type == "bspline": 

self.wavelet = <gsl_wavelet*> gsl_wavelet_alloc(gsl_wavelet_bspline,wavelet_k) 

elif wavelet_type == "bspline_centered": 

self.wavelet = <gsl_wavelet*> gsl_wavelet_alloc(gsl_wavelet_bspline_centered,wavelet_k) 

self.workspace = <gsl_wavelet_workspace*> gsl_wavelet_workspace_alloc(n) 

  

def __dealloc__(self): 

if self.wavelet != NULL: 

gsl_wavelet_free(self.wavelet) 

gsl_wavelet_workspace_free(self.workspace) 

  

def forward_transform(self): 

gsl_wavelet_transform_forward(self.wavelet,self.data,self.stride,self.n,self.workspace) 

  

def backward_transform(self): 

gsl_wavelet_transform_inverse(self.wavelet,self.data,self.stride,self.n,self.workspace) 

  

def plot(self,xmin=None,xmax=None,**args): 

cdef int i 

cdef double x 

v = [] 

point = sage.plot.all.point 

if xmin is None: 

x_min = 0 

if xmax is None: 

x_max=self.n 

for i from x_min <=i < x_max: 

x = self.data[i] 

if i >0: 

v.append(point([(i,x)],hue=(1,1,1),**args)) 

return sum(v) 

  

  

def is2pow(unsigned int n): 

while n != 0 and n%2 == 0: 

n = n >> 1 

return n == 1