Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions neo/core/epoch.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,12 @@ class Epoch(DataObject):
dtype='|S4')

*Required attributes/properties*:
:times: (quantity array 1D) The start times of each time period.
:durations: (quantity array 1D or quantity scalar) The length(s) of each time period.
:times: (quantity array 1D, numpy array 1D or list) The start times
of each time period.
:durations: (quantity array 1D, numpy array 1D, list, or quantity scalar)
The length(s) of each time period.
If a scalar, the same value is used for all time periods.
:labels: (numpy.array 1D dtype='S') Names or labels for the time periods.
:labels: (numpy.array 1D dtype='S' or list) Names or labels for the time periods.

*Recommended attributes/properties*:
:name: (str) A label for the dataset,
Expand All @@ -87,6 +89,10 @@ def __new__(cls, times=None, durations=None, labels=None, units=None, name=None,
description=None, file_origin=None, array_annotations=None, **annotations):
if times is None:
times = np.array([]) * pq.s
elif isinstance(times, (list, tuple)):
times = np.array(times)
if isinstance(durations, (list, tuple)):
durations = np.array(durations)
if durations is None:
durations = np.array([]) * pq.s
elif durations.size != times.size:
Expand All @@ -112,6 +118,8 @@ def __new__(cls, times=None, durations=None, labels=None, units=None, name=None,
dim = units.dimensionality
else:
dim = pq.quantity.validate_dimensionality(units)
if not hasattr(durations, "dimensionality"):
durations = pq.Quantity(durations, dim)
# check to make sure the units are time
# this approach is much faster than comparing the
# reference dimensionality
Expand Down Expand Up @@ -189,8 +197,7 @@ def __getitem__(self, i):
'''
Get the item or slice :attr:`i`.
'''
obj = Epoch(times=super(Epoch, self).__getitem__(i))
obj._copy_data_complement(self)
obj = super(Epoch, self).__getitem__(i)
obj._durations = self.durations[i]
if self._labels is not None and self._labels.size > 0:
obj._labels = self.labels[i]
Expand All @@ -199,8 +206,11 @@ def __getitem__(self, i):
try:
# Array annotations need to be sliced accordingly
obj.array_annotate(**deepcopy(self.array_annotations_at_index(i)))
obj._copy_data_complement(self)
except AttributeError: # If Quantity was returned, not Epoch
pass
obj.times = obj
obj.durations = obj._durations
obj.labels = obj._labels
return obj

def __getslice__(self, i, j):
Expand Down
11 changes: 7 additions & 4 deletions neo/core/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ class Event(DataObject):
dtype='|S5')

*Required attributes/properties*:
:times: (quantity array 1D) The time of the events.
:labels: (numpy.array 1D dtype='S') Names or labels for the events.
:times: (quantity array 1D, numpy array 1D or list) The times of the events.
:labels: (numpy.array 1D dtype='S' or list) Names or labels for the events.

*Recommended attributes/properties*:
:name: (str) A label for the dataset.
Expand All @@ -81,6 +81,8 @@ def __new__(cls, times=None, labels=None, units=None, name=None, description=Non
file_origin=None, array_annotations=None, **annotations):
if times is None:
times = np.array([]) * pq.s
elif isinstance(times, (list, tuple)):
times = np.array(times)
if labels is None:
labels = np.array([], dtype='S')
else:
Expand Down Expand Up @@ -211,7 +213,7 @@ def _copy_data_complement(self, other):
# Note: Array annotations, including labels, cannot be copied
# because they are linked to their respective timestamps and length of data can be changed
# here which would cause inconsistencies
for attr in ("_labels", "name", "file_origin", "description",
for attr in ("name", "file_origin", "description",
"annotations"):
setattr(self, attr, deepcopy(getattr(other, attr, None)))

Expand All @@ -223,8 +225,9 @@ def __getitem__(self, i):
obj.labels = self._labels
try:
obj.array_annotate(**deepcopy(self.array_annotations_at_index(i)))
obj._copy_data_complement(self)
except AttributeError: # If Quantity was returned, not Event
pass
obj.times = obj
return obj

def set_labels(self, labels):
Expand Down
14 changes: 12 additions & 2 deletions neo/test/coretest/test_epoch.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,16 @@ def test_Epoch_creation_scalar_duration(self):
assert_arrays_equal(epc.labels,
np.array(['test epoch 1', 'test epoch 2', 'test epoch 3'], dtype='S'))

def test_Epoch_creation_from_lists(self):
epc = Epoch([1.1, 1.5, 1.7],
[20.0, 20.0, 20.0],
['test event 1', 'test event 2', 'test event 3'],
units=pq.ms)
assert_arrays_equal(epc.times, [1.1, 1.5, 1.7] * pq.ms)
assert_arrays_equal(epc.durations, [20.0, 20.0, 20.0] * pq.ms)
assert_arrays_equal(epc.labels,
np.array(['test event 1', 'test event 2', 'test event 3']))

def test_Epoch_repr(self):
params = {'test2': 'y1', 'test3': True}
epc = Epoch([1.1, 1.5, 1.7] * pq.ms, durations=[20, 40, 60] * pq.ns,
Expand Down Expand Up @@ -654,13 +664,13 @@ def test_as_quantity(self):
self.assertIsInstance(epc_as_q, pq.Quantity)
assert_array_equal(times * pq.ms, epc_as_q)

def test_getitem(self):
def test_getitem_scalar(self):
times = [2, 3, 4, 5]
durations = [0.1, 0.2, 0.3, 0.4]
labels = ["A", "B", "C", "D"]
epc = Epoch(times * pq.ms, durations=durations * pq.ms, labels=labels)
single_epoch = epc[2]
self.assertIsInstance(single_epoch, Epoch)
self.assertIsInstance(single_epoch, pq.Quantity)
assert_array_equal(single_epoch.times, np.array([4.0]))
assert_array_equal(single_epoch.durations, np.array([0.3]))
assert_array_equal(single_epoch.labels, np.array(["C"]))
Expand Down
12 changes: 7 additions & 5 deletions neo/test/coretest/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,11 +146,13 @@ def test_Event_creation_invalid_labels(self):
self.assertRaises(ValueError, Event, [1.1, 1.5, 1.7] * pq.ms,
labels=["A", "B"])

def test_Epoch_creation_scalar_duration(self):
# test with scalar for durations
epc = Epoch([1.1, 1.5, 1.7] * pq.ms, durations=20 * pq.ns,
labels=np.array(['test epoch 1', 'test epoch 2', 'test epoch 3'], dtype='S'))
assert_neo_object_is_compliant(epc)
def test_Event_creation_from_lists(self):
evt = Event([1.1, 1.5, 1.7],
['test event 1', 'test event 2', 'test event 3'],
units=pq.ms)
assert_arrays_equal(evt.times, [1.1, 1.5, 1.7] * pq.ms)
assert_arrays_equal(evt.labels,
np.array(['test event 1', 'test event 2', 'test event 3']))

def tests_time_slice(self):

Expand Down