Skip to content
32 changes: 31 additions & 1 deletion tests/test_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def test_1f(self):
"Mismatch in %s" % name,
)

def test_read_flac(self):
def test_read_write_flac(self):
"""
All FLAC formats, multiple signal files in one record.

Expand Down Expand Up @@ -250,6 +250,28 @@ def test_read_flac(self):
f"Mismatch in {name}",
)

# Test file writing
record.wrsamp()
record_write = wfdb.rdrecord("flacformats", physical=False)
assert record == record_write

def test_read_write_flac_multifrequency(self):
"""
Format 516 with multiple signal files and variable samples per frame.
"""
# Check that we can read a record and write it out again
record = wfdb.rdrecord(
"sample-data/mixedsignals",
physical=False,
smooth_frames=False,
)
record.wrsamp(expanded=True)

# Check that result matches the original
record = wfdb.rdrecord("sample-data/mixedsignals", smooth_frames=False)
record_write = wfdb.rdrecord("mixedsignals", smooth_frames=False)
assert record == record_write

def test_read_flac_longduration(self):
"""
Three signals multiplexed in a FLAC file, over 2**24 samples.
Expand Down Expand Up @@ -628,6 +650,14 @@ def tearDownClass(cls):
"100_3chan.hea",
"a103l.hea",
"a103l.mat",
"flacformats.d0",
"flacformats.d1",
"flacformats.d2",
"flacformats.hea",
"mixedsignals.hea",
"mixedsignals_e.dat",
"mixedsignals_p.dat",
"mixedsignals_r.dat",
"s0010_re.dat",
"s0010_re.hea",
"s0010_re.xyz",
Expand Down
31 changes: 30 additions & 1 deletion wfdb/io/_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,35 @@ def get_write_fields(self):

return rec_write_fields, sig_write_fields

def _auto_signal_file_names(self):
fmt = self.fmt or [None] * self.n_sig
spf = self.samps_per_frame or [None] * self.n_sig
num_groups = 0
group_number = []
prev_fmt = prev_spf = None
channels_in_group = 0

for ch_fmt, ch_spf in zip(fmt, spf):
if ch_fmt != prev_fmt:
num_groups += 1
channels_in_group = 0
elif ch_fmt in ("508", "516", "524"):
if channels_in_group >= 8 or ch_spf != prev_spf:
num_groups += 1
channels_in_group = 0
group_number.append(num_groups)
prev_fmt = ch_fmt
prev_spf = ch_spf

if num_groups < 2:
return [self.record_name + ".dat"] * self.n_sig
else:
digits = len(str(group_number[-1]))
return [
self.record_name + "_" + str(g).rjust(digits, "0") + ".dat"
for g in group_number
]

def set_default(self, field):
"""
Set the object's attribute to its default value if it is missing
Expand Down Expand Up @@ -394,7 +423,7 @@ def set_default(self, field):

# Specific dynamic case
if field == "file_name" and self.file_name is None:
self.file_name = self.n_sig * [self.record_name + ".dat"]
self.file_name = self._auto_signal_file_names()
return

item = getattr(self, field)
Expand Down
105 changes: 82 additions & 23 deletions wfdb/io/_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,12 +950,11 @@ def wr_dat_files(self, expanded=False, write_dir=""):
dat_offsets[fn],
True,
[self.e_d_signal[ch] for ch in dat_channels[fn]],
self.samps_per_frame,
[self.samps_per_frame[ch] for ch in dat_channels[fn]],
write_dir=write_dir,
)
else:
# Create a copy to prevent overwrite
dsig = self.d_signal.copy()
dsig = self.d_signal
for fn in file_names:
wr_dat_file(
fn,
Expand Down Expand Up @@ -2267,16 +2266,15 @@ def wr_dat_file(
fmt : str
WFDB fmt of the dat file.
d_signal : ndarray
The digital conversion of the signal. Either a 2d numpy
array or a list of 1d numpy arrays.
The digital conversion of the signal, as a 2d numpy array.
byte_offset : int
The byte offset of the dat file.
expanded : bool, optional
Whether to transform the `e_d_signal` attribute (True) or
the `d_signal` attribute (False).
d_signal : ndarray, optional
The expanded digital conversion of the signal. Either a 2d numpy
array or a list of 1d numpy arrays.
e_d_signal : ndarray, optional
The expanded digital conversion of the signal, as a list of 1d
numpy arrays.
samps_per_frame : list, optional
The samples/frame for each signal of the dat file.
write_dir : str, optional
Expand All @@ -2287,10 +2285,19 @@ def wr_dat_file(
N/A

"""
file_path = os.path.join(write_dir, file_name)

# Combine list of arrays into single array
if expanded:
n_sig = len(e_d_signal)
sig_len = int(len(e_d_signal[0]) / samps_per_frame[0])
if len(samps_per_frame) != n_sig:
raise ValueError("mismatch between samps_per_frame and e_d_signal")

sig_len = len(e_d_signal[0]) // samps_per_frame[0]
for sig, spf in zip(e_d_signal, samps_per_frame):
if len(sig) != sig_len * spf:
raise ValueError("mismatch in lengths of expanded signals")

# Effectively create MxN signal, with extra frame samples acting
# like extra channels
d_signal = np.zeros((sig_len, sum(samps_per_frame)), dtype="int64")
Expand All @@ -2301,10 +2308,17 @@ def wr_dat_file(
for framenum in range(spf):
d_signal[:, expand_ch] = e_d_signal[ch][framenum::spf]
expand_ch = expand_ch + 1
else:
# Create a copy to prevent overwrite
d_signal = d_signal.copy()

# This n_sig is used for making list items.
# Does not necessarily represent number of signals (ie. for expanded=True)
n_sig = d_signal.shape[1]
# Non-expanded format always has 1 sample per frame
n_sig = d_signal.shape[1]
samps_per_frame = [1] * n_sig

# Total number of samples per frame (equal to number of signals if
# expanded=False, but may be greater for expanded=True)
tsamps_per_frame = d_signal.shape[1]

if fmt == "80":
# convert to 8 bit offset binary form
Expand Down Expand Up @@ -2362,8 +2376,8 @@ def wr_dat_file(
# convert to 16 bit two's complement
d_signal[d_signal < 0] = d_signal[d_signal < 0] + 65536
# Split samples into separate bytes using binary masks
b1 = d_signal & [255] * n_sig
b2 = (d_signal & [65280] * n_sig) >> 8
b1 = d_signal & [255] * tsamps_per_frame
b2 = (d_signal & [65280] * tsamps_per_frame) >> 8
# Interweave the bytes so that the same samples' bytes are consecutive
b1 = b1.reshape((-1, 1))
b2 = b2.reshape((-1, 1))
Expand All @@ -2375,9 +2389,9 @@ def wr_dat_file(
# convert to 24 bit two's complement
d_signal[d_signal < 0] = d_signal[d_signal < 0] + 16777216
# Split samples into separate bytes using binary masks
b1 = d_signal & [255] * n_sig
b2 = (d_signal & [65280] * n_sig) >> 8
b3 = (d_signal & [16711680] * n_sig) >> 16
b1 = d_signal & [255] * tsamps_per_frame
b2 = (d_signal & [65280] * tsamps_per_frame) >> 8
b3 = (d_signal & [16711680] * tsamps_per_frame) >> 16
# Interweave the bytes so that the same samples' bytes are consecutive
b1 = b1.reshape((-1, 1))
b2 = b2.reshape((-1, 1))
Expand All @@ -2391,10 +2405,10 @@ def wr_dat_file(
# convert to 32 bit two's complement
d_signal[d_signal < 0] = d_signal[d_signal < 0] + 4294967296
# Split samples into separate bytes using binary masks
b1 = d_signal & [255] * n_sig
b2 = (d_signal & [65280] * n_sig) >> 8
b3 = (d_signal & [16711680] * n_sig) >> 16
b4 = (d_signal & [4278190080] * n_sig) >> 24
b1 = d_signal & [255] * tsamps_per_frame
b2 = (d_signal & [65280] * tsamps_per_frame) >> 8
b3 = (d_signal & [16711680] * tsamps_per_frame) >> 16
b4 = (d_signal & [4278190080] * tsamps_per_frame) >> 24
# Interweave the bytes so that the same samples' bytes are consecutive
b1 = b1.reshape((-1, 1))
b2 = b2.reshape((-1, 1))
Expand All @@ -2404,9 +2418,54 @@ def wr_dat_file(
b_write = b_write.reshape((1, -1))[0]
# Convert to un_signed 8 bit dtype to write
b_write = b_write.astype("uint8")

elif fmt in ("508", "516", "524"):
import soundfile

if any(spf != samps_per_frame[0] for spf in samps_per_frame):
raise ValueError(
"All channels in a FLAC signal file must have the same "
"sampling rate and samples per frame"
)
if n_sig > 8:
raise ValueError(
"A single FLAC signal file cannot contain more than 8 channels"
)

d_signal = d_signal.reshape(-1, n_sig, samps_per_frame[0])
d_signal = d_signal.transpose(0, 2, 1)
d_signal = d_signal.reshape(-1, n_sig)

if fmt == "508":
d_signal = d_signal.astype("int16")
np.left_shift(d_signal, 8, out=d_signal)
subtype = "PCM_S8"
elif fmt == "516":
d_signal = d_signal.astype("int16")
subtype = "PCM_16"
elif fmt == "524":
d_signal = d_signal.astype("int32")
np.left_shift(d_signal, 8, out=d_signal)
subtype = "PCM_24"
else:
raise ValueError(f"unknown format ({fmt})")

sf = soundfile.SoundFile(
file_path,
mode="w",
samplerate=96000,
channels=n_sig,
subtype=subtype,
format="FLAC",
)
with sf:
sf.write(d_signal)
return

else:
raise ValueError(
"This library currently only supports writing the following formats: 80, 16, 24, 32"
"This library currently only supports writing the "
"following formats: 80, 16, 24, 32, 508, 516, 524"
)

# Byte offset in the file
Expand All @@ -2421,7 +2480,7 @@ def wr_dat_file(
b_write = np.append(np.zeros(byte_offset, dtype="uint8"), b_write)

# Write the bytes to the file
with open(os.path.join(write_dir, file_name), "wb") as f:
with open(file_path, "wb") as f:
b_write.tofile(f)


Expand Down
10 changes: 8 additions & 2 deletions wfdb/io/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,9 +520,15 @@ def check_field(self, field, required_channels="all"):
"block_size values must be non-negative integers"
)
elif field == "sig_name":
if re.search(r"\s", item[ch]):
if item[ch][:1].isspace() or item[ch][-1:].isspace():
raise ValueError(
"sig_name strings may not begin or end with "
"whitespace."
)
if re.search(r"[\x00-\x1f\x7f-\x9f]", item[ch]):
raise ValueError(
"sig_name strings may not contain whitespaces."
"sig_name strings may not contain "
"control characters."
)
if len(set(item)) != len(item):
raise ValueError("sig_name strings must be unique.")
Expand Down