Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 57 additions & 48 deletions neo/io/nixio.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def __init__(self, filename, mode="rw"):
"Valid modes: 'ro' (ReadOnly)', 'rw' (ReadWrite),"
" 'ow' (Overwrite).".format(mode))
self.nix_file = nix.File.open(self.filename, filemode, backend="h5py")
self._object_map = dict()
self._neo_map = dict()
self._nix_map = dict()
self._lazy_loaded = list()
self._object_hashes = dict()
self._block_read_counter = 0
Expand Down Expand Up @@ -156,7 +157,7 @@ def read_segment(self, path, cascade=True, lazy=False):
self._read_cascade(nix_group, path, cascade, lazy)
self._update_maps(neo_segment, lazy)
nix_parent = self._get_parent(path)
neo_parent = self._get_mapped_object(nix_parent)
neo_parent = self._neo_map.get(nix_parent.name)
if neo_parent:
neo_segment.block = neo_parent
return neo_segment
Expand All @@ -169,7 +170,7 @@ def read_channelindex(self, path, cascade=True, lazy=False):
self._read_cascade(nix_source, path, cascade, lazy)
self._update_maps(neo_rcg, lazy)
nix_parent = self._get_parent(path)
neo_parent = self._get_mapped_object(nix_parent)
neo_parent = self._neo_map.get(nix_parent.name)
neo_rcg.block = neo_parent
return neo_rcg

Expand All @@ -196,7 +197,7 @@ def read_signal(self, path, lazy=False):
if self._find_lazy_loaded(neo_signal) is None:
self._update_maps(neo_signal, lazy)
nix_parent = self._get_parent(path)
neo_parent = self._get_mapped_object(nix_parent)
neo_parent = self._neo_map.get(nix_parent.name)
neo_signal.segment = neo_parent
return neo_signal

Expand All @@ -212,7 +213,7 @@ def read_eest(self, path, lazy=False):
neo_eest.path = path
self._update_maps(neo_eest, lazy)
nix_parent = self._get_parent(path)
neo_parent = self._get_mapped_object(nix_parent)
neo_parent = self._neo_map.get(nix_parent.name)
neo_eest.segment = neo_parent
return neo_eest

Expand All @@ -233,7 +234,7 @@ def read_unit(self, path, cascade=True, lazy=False):
self._read_cascade(nix_source, path, cascade, lazy)
self._update_maps(neo_unit, lazy)
nix_parent = self._get_parent(path)
neo_parent = self._get_mapped_object(nix_parent)
neo_parent = self._neo_map.get(nix_parent.name)
neo_unit.channel_index = neo_parent
return neo_unit

Expand All @@ -243,7 +244,7 @@ def _block_to_neo(self, nix_block):
neo_block.rec_datetime = datetime.fromtimestamp(
nix_block.created_at
)
self._object_map[nix_block.id] = neo_block
self._neo_map[nix_block.name] = neo_block
return neo_block

def _group_to_neo(self, nix_group):
Expand All @@ -252,7 +253,7 @@ def _group_to_neo(self, nix_group):
neo_segment.rec_datetime = datetime.fromtimestamp(
nix_group.created_at
)
self._object_map[nix_group.id] = neo_segment
self._neo_map[nix_group.name] = neo_segment
return neo_segment

def _source_chx_to_neo(self, nix_source):
Expand All @@ -261,21 +262,24 @@ def _source_chx_to_neo(self, nix_source):
for c in nix_source.sources
if c.type == "neo.channelindex")
chan_names = list(c["neo_name"] for c in chx if "neo_name" in c)
chan_ids = list(c["channel_id"] for c in chx if "channel_id" in c)
if chan_names:
neo_attrs["channel_names"] = chan_names
if chan_ids:
neo_attrs["channel_ids"] = chan_ids
neo_attrs["index"] = np.array([c["index"] for c in chx])
if "coordinates" in chx[0]:
coord_units = chx[0]["coordinates.units"]
coord_values = list(c["coordinates"] for c in chx)
neo_attrs["coordinates"] = pq.Quantity(coord_values, coord_units)
rcg = ChannelIndex(**neo_attrs)
self._object_map[nix_source.id] = rcg
self._neo_map[nix_source.name] = rcg
return rcg

def _source_unit_to_neo(self, nix_unit):
neo_attrs = self._nix_attr_to_neo(nix_unit)
neo_unit = Unit(**neo_attrs)
self._object_map[nix_unit.id] = neo_unit
self._neo_map[nix_unit.name] = neo_unit
return neo_unit

def _signal_da_to_neo(self, nix_da_group, lazy):
Expand Down Expand Up @@ -336,7 +340,7 @@ def _signal_da_to_neo(self, nix_da_group, lazy):
else:
return None
for da in nix_da_group:
self._object_map[da.id] = neo_signal
self._neo_map[da.name] = neo_signal
if lazy_shape:
neo_signal.lazy_shape = lazy_shape
return neo_signal
Expand Down Expand Up @@ -426,13 +430,13 @@ def _mtag_eest_to_neo(self, nix_mtag, lazy):
)
else:
return None
self._object_map[nix_mtag.id] = eest
self._neo_map[nix_mtag.name] = eest
if lazy_shape:
eest.lazy_shape = lazy_shape
return eest

def _read_cascade(self, nix_obj, path, cascade, lazy):
neo_obj = self._object_map[nix_obj.id]
neo_obj = self._neo_map[nix_obj.name]
for neocontainer in getattr(neo_obj, "_child_containers", []):
nixcontainer = self._container_map[neocontainer]
if not hasattr(nix_obj, nixcontainer):
Expand Down Expand Up @@ -460,7 +464,7 @@ def _read_cascade(self, nix_obj, path, cascade, lazy):
parent_block_path = "/" + path.split("/")[1]
parent_block = self._get_object_at(parent_block_path)
ref_das = self._get_referers(nix_obj, parent_block.data_arrays)
ref_signals = self._get_mapped_objects(ref_das)
ref_signals = list(self._neo_map[da.name] for da in ref_das)
# deduplicate by name
ref_signals = list(dict((s.annotations["nix_name"], s)
for s in ref_signals).values())
Expand All @@ -476,7 +480,7 @@ def _read_cascade(self, nix_obj, path, cascade, lazy):
parent_block_path = "/" + path.split("/")[1]
parent_block = self._get_object_at(parent_block_path)
ref_mtags = self._get_referers(nix_obj, parent_block.multi_tags)
ref_sts = self._get_mapped_objects(ref_mtags)
ref_sts = list(self._neo_map[mt.name] for mt in ref_mtags)
for st in ref_sts:
neo_obj.spiketrains.append(st)
st.unit = neo_obj
Expand Down Expand Up @@ -533,7 +537,7 @@ def _write_object(self, obj, loc=""):
nix_name = "neo.{}.{}".format(objtype, self._generate_nix_name())
obj.annotate(nix_name=nix_name)
objpath = loc + containerstr + nix_name
oldhash = self._object_hashes.get(objpath)
oldhash = self._object_hashes.get(nix_name)
if oldhash is None:
try:
oldobj = self.get(objpath, cascade=False, lazy=False)
Expand All @@ -554,9 +558,16 @@ def _write_object(self, obj, loc=""):
if isinstance(obj, pq.Quantity):
self._write_data(nixobj, attr, objpath)
else:
nixobj = self._get_object_at(objpath)
self._object_map[id(obj)] = nixobj
self._object_hashes[objpath] = newhash
nixobj = self._nix_map.get(nix_name)
if nixobj is None:
nixobj = self._get_object_at(objpath)
else:
# object is already in file but may not be linked at objpath
objat = self._get_object_at(objpath)
if not objat:
self._link_nix_obj(nixobj, loc, containerstr)
self._nix_map[nix_name] = nixobj
self._object_hashes[nix_name] = newhash
self._write_cascade(obj, objpath)

def _create_nix_obj(self, loc, attr):
Expand Down Expand Up @@ -615,6 +626,15 @@ def _create_nix_obj(self, loc, attr):
raise ValueError("Unable to create NIX object. Invalid type.")
return nixobj

def _link_nix_obj(self, obj, loc, neocontainer):
parentobj = self._get_object_at(loc)
container = getattr(parentobj,
self._container_map[neocontainer.strip("/")])
if isinstance(obj, list):
container.extend(obj)
else:
container.append(obj)

def write_block(self, bl, loc=""):
"""
Convert ``bl`` to the NIX equivalent and write it to the file.
Expand Down Expand Up @@ -654,7 +674,7 @@ def write_indices(self, chx, loc=""):
:param chx: The Neo ChannelIndex
:param loc: Path to the CHX
"""
nixsource = self._get_mapped_object(chx)
nixsource = self._nix_map[chx.annotations["nix_name"]]
for idx, channel in enumerate(chx.index):
channame = "{}.ChannelIndex{}".format(chx.annotations["nix_name"],
idx)
Expand All @@ -668,10 +688,13 @@ def write_indices(self, chx, loc=""):
)
nixchan.definition = nixsource.definition
chanmd = nixchan.metadata
chanmd["index"] = nix.Value(int(channel))
if len(chx.channel_names):
neochanname = stringify(chx.channel_names[idx])
chanmd["neo_name"] = nix.Value(neochanname)
chanmd["index"] = nix.Value(int(channel))
if len(chx.channel_ids):
chanid = chx.channel_ids[idx]
chanmd["channel_id"] = nix.Value(chanid)
if chx.coordinates is not None:
coords = chx.coordinates[idx]
coordunits = stringify(coords[0].dimensionality)
Expand Down Expand Up @@ -782,27 +805,28 @@ def _create_references(self, block):
NIX objects.
"""
for seg in block.segments:
group = self._get_mapped_object(seg)
group = self._nix_map[seg.annotations["nix_name"]]
group_signals = self._get_contained_signals(group)
for mtag in group.multi_tags:
if mtag.type in ("neo.epoch", "neo.event"):
mtag.references.extend([sig for sig in group_signals
if sig not in mtag.references])
for rcg in block.channel_indexes:
rcgsource = self._get_mapped_object(rcg)
das = self._get_mapped_objects(rcg.analogsignals +
rcg.irregularlysampledsignals)
chidxsrc = self._nix_map[rcg.annotations["nix_name"]]
das = list(self._nix_map[sig.annotations["nix_name"]]
for sig in rcg.analogsignals +
rcg.irregularlysampledsignals)
# flatten nested lists
das = [da for dalist in das for da in dalist]
for da in das:
if rcgsource not in da.sources:
da.sources.append(rcgsource)
if chidxsrc not in da.sources:
da.sources.append(chidxsrc)
for unit in rcg.units:
unitsource = self._get_mapped_object(unit)
unitsource = self._nix_map[unit.annotations["nix_name"]]
for st in unit.spiketrains:
stmtag = self._get_mapped_object(st)
if rcgsource not in stmtag.sources:
stmtag.sources.append(rcgsource)
stmtag = self._nix_map[st.annotations["nix_name"]]
if chidxsrc not in stmtag.sources:
stmtag.sources.append(chidxsrc)
if unitsource not in stmtag.sources:
stmtag.sources.append(unitsource)

Expand Down Expand Up @@ -853,21 +877,6 @@ def _get_parent(self, path):
parent_obj = self._get_object_at(parent_path)
return parent_obj

def _get_mapped_objects(self, object_list):
return list(map(self._get_mapped_object, object_list))

def _get_mapped_object(self, obj):
# We could use paths here instead
try:
if hasattr(obj, "id"):
return self._object_map[obj.id]
else:
return self._object_map[id(obj)]
except KeyError:
# raise KeyError("Failed to find mapped object for {}. "
# "Object not yet converted.".format(obj))
return None

def _write_attr_annotations(self, nixobj, attr, path):
if isinstance(nixobj, list):
metadata = nixobj[0].metadata
Expand Down Expand Up @@ -970,7 +979,8 @@ def _update_maps(self, obj, lazy):
elif not lazy and objidx is not None:
self._lazy_loaded.pop(objidx)
if not lazy:
self._object_hashes[obj.path] = self._hash_object(obj)
nix_name = obj.annotations["nix_name"]
self._object_hashes[nix_name] = self._hash_object(obj)

def _find_lazy_loaded(self, obj):
"""
Expand Down Expand Up @@ -1260,7 +1270,6 @@ def close(self):
self.nix_file and self.nix_file.is_open()):
self.nix_file.close()
self.nix_file = None
self._object_map = None
self._lazy_loaded = None
self._object_hashes = None
self._block_read_counter = None
Expand Down
Loading