1 """
2 Output functions (see basic.py) and projection-level output functions
3 (see projfns.py) written in C to optimize performance.
4
5 Requires the weave package; without it unoptimized versions are used.
6 """
7
8 from topo.base.cf import CFPOutputFn,CFPOF_Plugin
9 from topo.base.functionfamilies import OutputFn, IdentityOF
10 from topo.base.parameterclasses import Number, ClassSelectorParameter
11 from topo.base.parameterizedobject import ParameterizedObject
12
13 from topo.misc.inlinec import inline, provide_unoptimized_equivalent
14
15 from basic import DivisiveNormalizeL1
16
17 from numpy.oldnumeric import sum
18
19
# The single-CF normalizer is not reimplemented in C here: the *_opt name is
# simply another name for the pure-Python DivisiveNormalizeL1 class.
DivisiveNormalizeL1_opt = DivisiveNormalizeL1
21
23 """
24 Performs divisive normalization of the weights of all cfs.
25 Intended to be equivalent to, but faster than,
26 CFPOF_DivisiveNormalizeL1.
27 """
28 single_cf_fn = ClassSelectorParameter(
29 OutputFn,DivisiveNormalizeL1(norm_value=1.0),readonly=True)
30
31 - def __call__(self, iterator, mask, **params):
32 rows,cols = mask.shape
33
34 cfs = iterator.proj._cfs
35
36
37
38 code = """
39 double *x = mask;
40 for (int r=0; r<rows; ++r) {
41 PyObject *cfsr = PyList_GetItem(cfs,r);
42 for (int l=0; l<cols; ++l) {
43 double load = *x++;
44 if (load != 0)
45 {
46 PyObject *cf = PyList_GetItem(cfsr,l);
47 PyObject *weights_obj = PyObject_GetAttrString(cf,"weights");
48 PyObject *slice_obj = PyObject_GetAttrString(cf,"input_sheet_slice");
49 PyObject *sum_obj = PyObject_GetAttrString(cf,"norm_total");
50
51 float *wi = (float *)(((PyArrayObject*)weights_obj)->data);
52 int *slice = (int *)(((PyArrayObject*)slice_obj)->data);
53 double total = PyFloat_AsDouble(sum_obj); // sum of the cf's weights
54
55 int rr1 = *slice++;
56 int rr2 = *slice++;
57 int cc1 = *slice++;
58 int cc2 = *slice;
59
60 // normalize the weights
61 double factor = 1.0/total;
62 int rc = (rr2-rr1)*(cc2-cc1);
63 for (int i=0; i<rc; ++i) {
64 *(wi++) *= factor;
65 }
66
67 // Anything obtained with PyObject_GetAttrString must be explicitly freed
68 Py_DECREF(weights_obj);
69 Py_DECREF(slice_obj);
70 Py_DECREF(sum_obj);
71
72 // Indicate that norm_total is stale
73 PyObject_SetAttrString(cf,"_has_norm_total",Py_False);
74 }
75 }
76 }
77 """
78 inline(code, ['mask','rows','cols','cfs'], local_dict=locals())
79
80
110
111
# Pair the C-optimized class with its pure-Python counterpart by name. The
# module docstring says unoptimized versions are used when weave is
# unavailable. How the substitution happens, and where
# CFPOF_DivisiveNormalizeL1 is resolved from, is up to topo.misc.inlinec.
provide_unoptimized_equivalent(
    "CFPOF_DivisiveNormalizeL1_opt",
    "CFPOF_DivisiveNormalizeL1",
    locals())
113
114
115
# Export every class at module level that is an OutputFn or CFPOutputFn
# subclass. Dict keys are already unique, so no set() is needed. Sorting gives
# __all__ a stable order. A generator expression (unlike a Python 2 list
# comprehension) does not leak its k/v loop variables into the module
# namespace. locals() is still evaluated here at module scope.
__all__ = sorted(
    k for k, v in locals().items()
    if isinstance(v, type) and issubclass(v, (OutputFn, CFPOutputFn)))
118