Package topo :: Package outputfns :: Module optimized
[hide private]
[frames] | no frames]

Source Code for Module topo.outputfns.optimized

  1  """ 
  2  Output functions (see basic.py) and projection-level output functions 
  3  (see projfns.py) written in C to optimize performance. 
  4   
  5  Requires the weave package; without it unoptimized versions are used. 
  6  """ 
  7   
  8  from topo.base.cf import CFPOutputFn,CFPOF_Plugin 
  9  from topo.base.functionfamilies import OutputFn, IdentityOF 
 10  from topo.base.parameterclasses import Number, ClassSelectorParameter 
 11  from topo.base.parameterizedobject import ParameterizedObject 
 12   
 13  from topo.misc.inlinec import inline, provide_unoptimized_equivalent 
 14   
 15  from basic import DivisiveNormalizeL1 
 16   
 17  from numpy.oldnumeric import sum 
 18   
# Backwards-compatibility alias: DivisiveNormalizeL1_opt is simply another
# name for DivisiveNormalizeL1, kept so that pickled files which refer to the
# old name can still be loaded.  It can be deleted once such files no longer
# need to be supported.
DivisiveNormalizeL1_opt=DivisiveNormalizeL1
 21   
class CFPOF_DivisiveNormalizeL1_opt(CFPOutputFn):
    """
    Performs divisive normalization of the weights of all cfs.
    Intended to be equivalent to, but faster than,
    CFPOF_DivisiveNormalizeL1.

    Like the unoptimized version, each selected CF's weights are
    scaled by single_cf_fn.norm_value / cf.norm_total (norm_value is
    1.0 for the default single_cf_fn).
    """
    single_cf_fn = ClassSelectorParameter(
        OutputFn,DivisiveNormalizeL1(norm_value=1.0),readonly=True)

    def __call__(self, iterator, mask, **params):
        """
        Scale, in place, the weights of every CF whose mask entry is nonzero.

        The work is done by inlined C code that walks mask row by row
        alongside iterator.proj._cfs.  Each selected CF's weights are
        multiplied by norm_value/cf.norm_total, and cf._has_norm_total
        is then set to False to mark the cached total as stale.
        """
        rows,cols = mask.shape

        cfs = iterator.proj._cfs
        # Use the same target sum as the unoptimized
        # CFPOF_DivisiveNormalizeL1 instead of hard-coding 1.0.
        norm_value = float(self.single_cf_fn.norm_value)

        # The original code normalized only the CFs for units that were
        # activated; it might be possible to restore that extra optimization
        # if some way is found to override that for the first iteration.
        code = """
            // NOTE(review): assumes mask is a contiguous float64 array and
            // cfs is a list (rows) of lists (cols) of CFs -- confirm.
            double *x = mask;
            for (int r=0; r<rows; ++r) {
                // PyList_GetItem returns a borrowed reference: no DECREF.
                PyObject *cfsr = PyList_GetItem(cfs,r);
                for (int l=0; l<cols; ++l) {
                    double load = *x++;
                    if (load != 0)
                    {
                        PyObject *cf = PyList_GetItem(cfsr,l);
                        PyObject *weights_obj = PyObject_GetAttrString(cf,"weights");
                        PyObject *slice_obj = PyObject_GetAttrString(cf,"input_sheet_slice");
                        PyObject *sum_obj = PyObject_GetAttrString(cf,"norm_total");
                        // NOTE(review): the three lookups above are not checked
                        // for NULL; a missing attribute would crash here.

                        // NOTE(review): assumes weights is contiguous float32 and
                        // input_sheet_slice contiguous int -- confirm.
                        float *wi = (float *)(((PyArrayObject*)weights_obj)->data);
                        int *slice = (int *)(((PyArrayObject*)slice_obj)->data);
                        double total = PyFloat_AsDouble(sum_obj); // sum of the cf's weights

                        int rr1 = *slice++;
                        int rr2 = *slice++;
                        int cc1 = *slice++;
                        int cc2 = *slice;

                        // scale the weights by norm_value/norm_total
                        double factor = norm_value/total;
                        int rc = (rr2-rr1)*(cc2-cc1);
                        for (int i=0; i<rc; ++i) {
                            *(wi++) *= factor;
                        }

                        // Anything obtained with PyObject_GetAttrString must be explicitly freed
                        Py_DECREF(weights_obj);
                        Py_DECREF(slice_obj);
                        Py_DECREF(sum_obj);

                        // Indicate that norm_total is stale
                        PyObject_SetAttrString(cf,"_has_norm_total",Py_False);
                    }
                }
            }
        """
        inline(code, ['mask','rows','cols','cfs','norm_value'], local_dict=locals())
class CFPOF_DivisiveNormalizeL1(CFPOutputFn):
    """
    Non-optimized version of CFPOF_DivisiveNormalizeL1_opt.

    Same as CFPOF_Plugin(single_cf_fn=DivisiveNormalizeL1), except
    that it supports joint normalization using the norm_total
    property of ConnectionField.
    """

    single_cf_fn = ClassSelectorParameter(
        OutputFn,default=DivisiveNormalizeL1(norm_value=1.0),constant=True)

    def __call__(self, iterator, mask, **params):
        """
        Uses the cf.norm_total attribute to allow optimization
        by computing the sum separately, and to allow joint
        normalization. After use, cf.norm_total is deleted because
        the value it would have has been changed.

        Only CFs whose entry in mask is nonzero are modified; nothing
        is done at all when single_cf_fn is an IdentityOF.
        """
        if type(self.single_cf_fn) is IdentityOF:
            return

        norm_value = self.single_cf_fn.norm_value
        for cf,r,c in iterator():
            if mask[r,c] != 0:
                cf.weights *= norm_value/cf.norm_total
                # The weights just changed, so the stored total is stale.
                del cf.norm_total


# Fall back to the unoptimized class under the _opt name when the inlined
# C version cannot be used.
provide_unoptimized_equivalent("CFPOF_DivisiveNormalizeL1_opt","CFPOF_DivisiveNormalizeL1",locals())


# Export every OutputFn / CFPOutputFn subclass visible at module level.
__all__ = sorted(set(k for k,v in locals().items()
                     if isinstance(v,type) and
                     (issubclass(v,OutputFn) or issubclass(v,CFPOutputFn))))