Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 17 additions & 6 deletions neo/io/neomatlabio.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,15 +276,20 @@ def create_struct_from_obj(self, ob):
struct[childname] = []

# attributes
for i, attr in enumerate(ob._all_attrs):
all_attrs = list(ob._all_attrs)
if hasattr(ob, 'annotations'):
all_attrs.append(('annotations', type(ob.annotations)))

for i, attr in enumerate(all_attrs):
attrname, attrtype = attr[0], attr[1]

# ~ if attrname =='':
# ~ struct['array'] = ob.magnitude
# ~ struct['units'] = ob.dimensionality.string
# ~ continue

if (hasattr(ob, '_quantity_attr') and ob._quantity_attr == attrname):
if (hasattr(ob, '_quantity_attr') and
ob._quantity_attr == attrname):
struct[attrname] = ob.magnitude
struct[attrname + '_units'] = ob.dimensionality.string
continue
Expand All @@ -308,7 +313,7 @@ def create_struct_from_obj(self, ob):

def create_ob_from_struct(self, struct, classname):
cl = class_by_name[classname]
# check if hinerits Quantity
# check if inherits Quantity
# ~ is_quantity = False
# ~ for attr in cl._necessary_attrs:
# ~ if attr[0] == '' and attr[1] == pq.Quantity:
Expand Down Expand Up @@ -372,13 +377,16 @@ def create_ob_from_struct(self, struct, classname):
if attrname.endswith('_units') or attrname == 'units':
# linked with another field
continue
if (hasattr(cl, '_quantity_attr') and cl._quantity_attr == attrname):

if hasattr(cl, '_quantity_attr') and cl._quantity_attr == attrname:
continue

item = getattr(struct, attrname)

attributes = cl._necessary_attrs + cl._recommended_attrs
dict_attributes = {a[0]: a[1:] for a in attributes}
attributes = cl._necessary_attrs + cl._recommended_attrs \
+ (('annotations', dict),)
dict_attributes = dict([(a[0], a[1:]) for a in attributes])

if attrname in dict_attributes:
attrtype = dict_attributes[attrname][0]
if attrtype == datetime:
Expand All @@ -398,6 +406,9 @@ def create_ob_from_struct(self, struct, classname):
item = pq.Quantity(item, units)
else:
item = pq.Quantity(item, units)
elif attrtype == dict:
# FIXME: works but doesn't convert nested struct to dict
item = {fn: getattr(item, fn) for fn in item._fieldnames}
else:
item = attrtype(item)

Expand Down
10 changes: 8 additions & 2 deletions neo/test/iotest/test_neomatlabio.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ class TestNeoMatlabIO(BaseTestIO, unittest.TestCase):
def test_write_read_single_spike(self):
block1 = Block()
seg = Segment('segment1')
spiketrain = SpikeTrain([1] * pq.s, t_stop=10 * pq.s, sampling_rate=1 * pq.Hz)
spiketrain1 = SpikeTrain([1] * pq.s, t_stop=10 * pq.s, sampling_rate=1 * pq.Hz)
spiketrain1.annotate(yep='yop')
block1.segments.append(seg)
seg.spiketrains.append(spiketrain)
seg.spiketrains.append(spiketrain1)

# write block
filename = BaseTestIO.get_filename_path(self, 'matlabiotestfile.mat')
Expand All @@ -35,6 +36,11 @@ def test_write_read_single_spike(self):
self.assertEqual(block1.segments[0].spiketrains[0],
block2.segments[0].spiketrains[0])

# test annotations
spiketrain2 = block2.segments[0].spiketrains[0]
assert 'yep' in spiketrain2.annotations
assert spiketrain2.annotations['yep'] == 'yop'


if __name__ == "__main__":
unittest.main()