# Source code for rythm_forge.display.spectrogram

import matplotlib.pyplot as plt
import numpy as np
from ..exceptions.exceptions import RythmForgeValueError, RythmForgeTypeError


def _time_tick_labels(n_frames, sr, hop_length, n_ticks=8):
    """Return ``n_ticks`` evenly spaced frame positions and ``m:ss`` labels for them.

    Frame index ``i`` is converted to seconds as ``i * hop_length / sr`` and
    rounded to the nearest whole second.
    """
    # n_ticks + 1 points over [0, n_frames], dropping the right edge.
    ticks = np.linspace(0, n_frames, n_ticks + 1)[:-1]
    labels = []
    for tick in ticks:
        # int() makes sure divmod works on a plain Python int regardless of
        # which scalar type round() returns for a NumPy float.
        minutes, seconds = divmod(int(round(tick * hop_length / sr)), 60)
        # Zero-pad seconds so 65 s reads "1:05" rather than the ambiguous "1.5".
        labels.append(f"{minutes}:{seconds:02d}")
    return ticks, labels


def spectrogram(
    data: np.ndarray,
    sr=44100,
    hop_length=512,
    yaxis=None,
    time=True,
    fmin=None,
    fmax=None,
    ax=None,
    *,
    cmap="magma",
    vmin=-80,
    vmax=0,
) -> "matplotlib.image.AxesImage":
    """
    Display a precomputed spectrogram with a linear, log, or mel y-axis.

    Parameters:
    -----------
    data : np.ndarray
        2D matrix of spectrogram values: one row per frequency bin, one
        column per frame (rows are drawn bottom-up). Must have a
        floating-point dtype (any ``np.floating`` subtype). Array-likes are
        converted with ``np.asarray``.
    sr : int, optional
        Sampling rate of the original audio signal, in Hz. Default is 44100.
    hop_length : int, optional
        Number of samples between successive frames. Only used to convert
        frame indices into time labels. Default is 512.
    yaxis : str, optional
        Scale of the y-axis:
        - None or "linear" (default): linear axis
        - "log": logarithmic axis
        - "mel": currently drawn on a linear axis, same as "linear"
    time : bool, optional
        If True, label the x-axis with time formatted as ``m:ss`` (rounded to
        whole seconds) instead of frame indices. Default is True.
    fmin : float, optional
        Lower frequency bound in Hz. Default is 20.
    fmax : float, optional
        Upper frequency bound in Hz. Default is sr / 2.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. If None, a new 10x5 figure and axes are created.
        All styling (scale, ticks, labels, limits) is applied to this axes.
    cmap : str, optional (keyword-only)
        Colormap name passed to ``imshow``. Default is "magma".
    vmin, vmax : float, optional (keyword-only)
        Color scale limits passed to ``imshow``. The defaults (-80, 0)
        presumably target dB-scaled magnitudes.

    Raises:
    -------
    RythmForgeTypeError
        If ``data`` does not have a floating-point dtype, or ``yaxis`` is
        not one of the supported options.
    RythmForgeValueError
        If ``data`` is not 2D, if ``fmin >= fmax``, or if ``fmin <= 0``
        with ``yaxis="log"``.

    Returns:
    --------
    matplotlib.image.AxesImage
        The image artist created by ``ax.imshow`` (usable with ``colorbar``).

    Notes:
    ------
    ``fmin`` and ``fmax`` set the image extent: the whole matrix is stretched
    over [fmin, fmax], and the y-limits are set to the same range. Rows are
    not cropped, so the frequency labels are only accurate if ``data`` already
    spans exactly that band.
    """
    data = np.asarray(data)
    if not np.issubdtype(data.dtype, np.floating):
        raise RythmForgeTypeError(
            "Unsupported data type in STFT matrix! Provide matrix with elements of type np.float64"
        )
    if data.ndim != 2:
        raise RythmForgeValueError(
            "Wrong STFT matrix dim number! STFT should have ndim=2"
        )

    if fmin is None:
        fmin = 20
    if fmax is None:
        fmax = sr / 2

    # Resolve and validate everything before drawing so a bad argument does
    # not leave a half-built figure behind.
    if yaxis in [None, "linear", "mel"]:
        yscale = "linear"
    elif yaxis == "log":
        yscale = "log"
    else:
        raise RythmForgeTypeError(f"Unknown yscale {yaxis}!")

    if fmin >= fmax:
        raise RythmForgeValueError(
            f"fmin ({fmin}) must be smaller than fmax ({fmax})!"
        )
    if yscale == "log" and fmin <= 0:
        raise RythmForgeValueError(
            f"fmin must be positive on a log-scaled axis, got {fmin}!"
        )

    created_figure = ax is None
    if created_figure:
        _, ax = plt.subplots(figsize=(10, 5))

    img = ax.imshow(
        data,
        aspect="auto",
        origin="lower",
        cmap=cmap,
        # x spans frame indices; y stretches all rows over [fmin, fmax].
        extent=[0, data.shape[1], fmin, fmax],
        vmin=vmin,
        vmax=vmax,
    )
    # Use the Axes API (not plt.*) so styling targets `ax`, even when the
    # caller's axes is not pyplot's current axes.
    ax.set_yscale(yscale)

    if time:
        ticks, labels = _time_tick_labels(data.shape[1], sr, hop_length)
        ax.set_xticks(ticks)
        ax.set_xticklabels(labels)
        ax.set_xlabel("Time")

    ax.set_ylabel("Hz")
    ax.set_ylim(fmin, fmax)

    if created_figure:
        # Only lay out a figure this function owns, and only after labels exist.
        ax.figure.tight_layout()
    return img