diff --git a/packages/markitdown/src/markitdown/_markitdown.py b/packages/markitdown/src/markitdown/_markitdown.py index 702b10c68..fc4a7da56 100644 --- a/packages/markitdown/src/markitdown/_markitdown.py +++ b/packages/markitdown/src/markitdown/_markitdown.py @@ -118,6 +118,8 @@ def __init__( self._llm_prompt: Union[str | None] = None self._exiftool_path: Union[str | None] = None self._style_map: Union[str | None] = None + self._transcription_engine: Union[str | None] = None + self._transcription_kwargs: dict = {} # Register the converters self._converters: List[ConverterRegistration] = [] @@ -143,6 +145,11 @@ def enable_builtins(self, **kwargs) -> None: self._llm_prompt = kwargs.get("llm_prompt") self._exiftool_path = kwargs.get("exiftool_path") self._style_map = kwargs.get("style_map") + self._transcription_engine = kwargs.get("transcription_engine") + self._transcription_kwargs = { + k: v for k, v in kwargs.items() + if k.startswith("transcription_") + } if self._exiftool_path is None: self._exiftool_path = os.getenv("EXIFTOOL_PATH") @@ -569,6 +576,11 @@ def _convert( if "exiftool_path" not in _kwargs and self._exiftool_path is not None: _kwargs["exiftool_path"] = self._exiftool_path + + # Copy transcription parameters + for key, value in self._transcription_kwargs.items(): + if key not in _kwargs: + _kwargs[key] = value # Add the list of converters for nested processing _kwargs["_parent_converters"] = self._converters diff --git a/packages/markitdown/src/markitdown/converters/_audio_converter.py b/packages/markitdown/src/markitdown/converters/_audio_converter.py index 3d96b53c8..511ffce9b 100644 --- a/packages/markitdown/src/markitdown/converters/_audio_converter.py +++ b/packages/markitdown/src/markitdown/converters/_audio_converter.py @@ -91,7 +91,23 @@ def convert( # Transcribe if audio_format: try: - transcript = transcribe_audio(file_stream, audio_format=audio_format) + # Extract transcription engine and parameters + engine = 
kwargs.get("transcription_engine", "google") + + # Build engine_kwargs from all transcription_* parameters + engine_kwargs = {} + for key, value in kwargs.items(): + if key.startswith("transcription_") and key != "transcription_engine": + # Remove 'transcription_' prefix to get the actual parameter name + param_name = key.replace("transcription_", "", 1) + engine_kwargs[param_name] = value + + transcript = transcribe_audio( + file_stream, + audio_format=audio_format, + engine=engine, + **engine_kwargs + ) if transcript: md_content += "\n\n### Audio Transcript:\n" + transcript except MissingDependencyException: diff --git a/packages/markitdown/src/markitdown/converters/_transcribe_audio.py b/packages/markitdown/src/markitdown/converters/_transcribe_audio.py index d558e4629..7d749385a 100644 --- a/packages/markitdown/src/markitdown/converters/_transcribe_audio.py +++ b/packages/markitdown/src/markitdown/converters/_transcribe_audio.py @@ -1,6 +1,6 @@ import io import sys -from typing import BinaryIO +from typing import Any, BinaryIO from .._exceptions import MissingDependencyException # Try loading optional (but in this case, required) dependencies @@ -20,7 +20,57 @@ _dependency_exc_info = sys.exc_info() -def transcribe_audio(file_stream: BinaryIO, *, audio_format: str = "wav") -> str: +def transcribe_audio(file_stream: BinaryIO, *, audio_format: str = "wav", engine: str = "google", **engine_kwargs: Any) -> str: + """ + Transcribe audio to text using various speech recognition engines. + This function is a wrapper around the SpeechRecognition library: https://github.com/Uberi/speech_recognition + + Args: + file_stream: Binary stream of the audio file + audio_format: Format of the audio file. Supported: + - Direct: 'wav', 'aiff', 'flac' + - Converted: 'mp3', 'mp4' + engine: Speech recognition engine to use. 
Supported:
+            - 'google': Google Speech Recognition (free, no API key, 1 minute per request, 50 requests per day) (https://pypi.org/project/SpeechRecognition/)
+            - 'google_cloud': Google Cloud Speech-to-Text (requires credentials_json) (https://cloud.google.com/speech-to-text/docs)
+            - 'wit': Wit.ai (requires key) (https://wit.ai/docs/http/)
+            - 'azure': Microsoft Azure (requires key, location) (https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-to-text)
+            - 'bing': Microsoft Bing (requires key) (https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-to-text)
+            - 'houndify': Houndify (requires client_id, client_key) (https://www.houndify.com/docs)
+            - 'assemblyai': AssemblyAI (requires api_token) (https://www.assemblyai.com/docs/)
+            - 'ibm': IBM Watson (requires key) (https://cloud.ibm.com/docs/speech-to-text)
+            - 'whisper_api': OpenAI Whisper API (requires api_key) (https://platform.openai.com/docs/api-reference/audio)
+            - 'sphinx': CMU Sphinx (offline, no API key) (https://cmusphinx.github.io/wiki/)
+        **engine_kwargs: Engine-specific parameters:
+            - google_cloud: credentials_json (path to JSON file)
+            - wit: key (API key)
+            - azure: key (API key), location (region), profanity (masked/removed/raw)
+            - bing: key (API key), language
+            - houndify: client_id, client_key
+            - assemblyai: api_token (API token)
+            - ibm: key (API key)
+            - whisper_api: api_key, model, language, prompt, temperature
+
+    Returns:
+        Transcribed text or "[No speech detected]" if no speech found
+
+    Raises:
+        ValueError: Invalid engine or audio format
+        MissingDependencyException: Required packages not installed
+        sr.RequestError: API request failed
+        sr.UnknownValueError: Speech could not be understood
+
+    Examples:
+        >>> # Google (free)
+        >>> with open("audio.mp3", "rb") as f:
+        ...     text = transcribe_audio(f, audio_format="mp3", engine="google")
+
+        >>> # Whisper API
+        >>> with open("audio.wav", "rb") as f:
+        ...
text = transcribe_audio(f, audio_format="wav",
+        ...                             engine="whisper_api",
+        ...                             api_key="sk-...")
+    """
     # Check for installed dependencies
     if _dependency_exc_info is not None:
         raise MissingDependencyException(
@@ -45,5 +95,26 @@ def transcribe_audio(file_stream: BinaryIO, *, audio_format: str = "wav") -> str
     recognizer = sr.Recognizer()
     with sr.AudioFile(audio_source) as source:
         audio = recognizer.record(source)
-        transcript = recognizer.recognize_google(audio).strip()
-    return "[No speech detected]" if transcript == "" else transcript
+
+    # Resolve the recognizer method for the requested engine; an unknown
+    # engine name has no recognize_<engine> attribute on sr.Recognizer.
+    try:
+        recognize_method = getattr(recognizer, f"recognize_{engine}")
+    except AttributeError:
+        raise ValueError(
+            f"Unsupported engine: '{engine}'. "
+            "Supported engines: google, google_cloud, wit, azure, bing, houndify, assemblyai, ibm, whisper_api, sphinx"
+        )
+
+    # Perform transcription with engine-specific error handling
+    try:
+        transcript = recognize_method(audio, **engine_kwargs).strip()
+        return "[No speech detected]" if transcript == "" else transcript
+    except sr.RequestError as e:
+        # API request failed (network, auth, quota, etc.)
+        raise ValueError(
+            f"Speech recognition request failed for engine '{engine}': {e}. "
+            f"Check your API credentials and network connection."
+        ) from e
+    except sr.UnknownValueError:
+        # Speech was unintelligible
+        return "[No speech detected]"
diff --git a/packages/markitdown/tests/test_transcribe_engines.py b/packages/markitdown/tests/test_transcribe_engines.py
new file mode 100644
index 000000000..b87839c20
--- /dev/null
+++ b/packages/markitdown/tests/test_transcribe_engines.py
@@ -0,0 +1,222 @@
+#!/usr/bin/env python3 -m pytest
+import os
+import pytest
+from markitdown.converters._transcribe_audio import transcribe_audio
+
+# This file contains tests for multi-engine speech recognition functionality.
+# Tests are skipped in CI and require audio test files and optional API keys.
+
+skip_transcription = (
+    True if os.environ.get("GITHUB_ACTIONS") else False
+)  # Don't run these tests in CI
+
+TEST_FILES_DIR = os.path.join(os.path.dirname(__file__), "test_files")
+
+# Test audio files with expected content
+AUDIO_TEST_FILES = [
+    ("test.wav", "wav"),
+    ("test.mp3", "mp3"),
+    ("test.m4a", "mp4"),  # M4A uses MP4 container format
+]
+
+
+def get_audio_file(filename: str) -> str:
+    """Get full path to test audio file."""
+    return os.path.join(TEST_FILES_DIR, filename)
+
+
+@pytest.mark.skipif(skip_transcription, reason="do not run speech transcription tests in CI")
+class TestEngineGoogle:
+    """Tests for Google Speech Recognition (free, no API key)."""
+
+    @pytest.mark.parametrize("filename,format", AUDIO_TEST_FILES)
+    def test_google_basic(self, filename: str, format: str) -> None:
+        """Test basic Google engine transcription."""
+        audio_path = get_audio_file(filename)
+
+        if not os.path.exists(audio_path):
+            pytest.skip(f"Test file not found: {audio_path}")
+
+        with open(audio_path, "rb") as f:
+            result = transcribe_audio(
+                f,
+                audio_format=format,
+                engine="google"
+            )
+
+        assert isinstance(result, str)
+        assert len(result) > 0
+        # Note: Result may be "[No speech detected]" for test files without speech
+
+
+@pytest.mark.skipif(skip_transcription, reason="do not run speech transcription tests in CI")
+@pytest.mark.skipif(
+    not os.environ.get("GOOGLE_CLOUD_SPEECH_CREDENTIALS"),
+    reason="do not run without GOOGLE_CLOUD_SPEECH_CREDENTIALS"
+)
+class TestEngineGoogleCloud:
+    """Tests for Google Cloud Speech-to-Text."""
+
+    def test_google_cloud_basic(self) -> None:
+        """Test Google Cloud Speech-to-Text."""
+        credentials_json = os.environ.get("GOOGLE_CLOUD_SPEECH_CREDENTIALS")
+        audio_path = get_audio_file("test.wav")
+
+        if not os.path.exists(audio_path):
+            pytest.skip("test.wav not found")
+
+        with open(audio_path, "rb") as f:
+            result = transcribe_audio(
+                f,
+                audio_format="wav",
+                engine="google_cloud",
+                credentials_json=credentials_json
+            )
+
+        assert isinstance(result, str)
+        assert len(result) > 0
+
+
+@pytest.mark.skipif(skip_transcription, reason="do not run speech transcription tests in CI")
+@pytest.mark.skipif(
+    not os.environ.get("WIT_AI_KEY"),
+    reason="do not run without WIT_AI_KEY"
+)
+class TestEngineWit:
+    """Tests for Wit.ai Speech Recognition."""
+
+    def test_wit_basic(self) -> None:
+        """Test Wit.ai transcription."""
+        wit_key = os.environ.get("WIT_AI_KEY")
+        audio_path = get_audio_file("test.wav")
+
+        if not os.path.exists(audio_path):
+            pytest.skip("test.wav not found")
+
+        with open(audio_path, "rb") as f:
+            result = transcribe_audio(
+                f,
+                audio_format="wav",
+                engine="wit",
+                key=wit_key
+            )
+
+        assert isinstance(result, str)
+        assert len(result) > 0
+
+
+@pytest.mark.skipif(skip_transcription, reason="do not run speech transcription tests in CI")
+@pytest.mark.skipif(
+    not os.environ.get("OPENAI_API_KEY"),
+    reason="do not run without OPENAI_API_KEY"
+)
+class TestEngineWhisperAPI:
+    """Tests for OpenAI Whisper API."""
+
+    def test_whisper_api_basic(self) -> None:
+        """Test Whisper API transcription."""
+        openai_key = os.environ.get("OPENAI_API_KEY")
+        audio_path = get_audio_file("test.wav")
+
+        if not os.path.exists(audio_path):
+            pytest.skip("test.wav not found")
+
+        with open(audio_path, "rb") as f:
+            result = transcribe_audio(
+                f,
+                audio_format="wav",
+                engine="whisper_api",
+                api_key=openai_key
+            )
+
+        assert isinstance(result, str)
+        assert len(result) > 0
+
+
+@pytest.mark.skipif(skip_transcription, reason="do not run speech transcription tests in CI")
+class TestEngineSphinx:
+    """Tests for CMU Sphinx (offline)."""
+
+    def test_sphinx_basic(self) -> None:
+        """Test Sphinx offline transcription."""
+        audio_path = get_audio_file("test.wav")
+
+        if not os.path.exists(audio_path):
+            pytest.skip("test.wav not found")
+
+        try:
+            with open(audio_path, "rb") as f:
+                result = transcribe_audio(
+                    f,
+                    audio_format="wav",
+                    engine="sphinx"
+                )
+
+            assert isinstance(result, str)
+        except Exception as e:
+            # Sphinx requires additional installation
+            if "pocketsphinx" in str(e).lower():
+                pytest.skip("PocketSphinx not installed")
+            raise
+
+
+@pytest.mark.skipif(skip_transcription, reason="do not run speech transcription tests in CI")
+class TestEngineErrors:
+    """Tests for error handling."""
+
+    def test_invalid_engine(self) -> None:
+        """Test that invalid engine raises ValueError."""
+        audio_path = get_audio_file("test.wav")
+
+        if not os.path.exists(audio_path):
+            pytest.skip("test.wav not found")
+
+        with pytest.raises(ValueError, match="Unsupported engine"):
+            with open(audio_path, "rb") as f:
+                transcribe_audio(
+                    f,
+                    audio_format="wav",
+                    engine="invalid_engine"
+                )
+
+    def test_invalid_audio_format(self) -> None:
+        """Test that invalid audio format raises ValueError."""
+        audio_path = get_audio_file("test.wav")
+
+        if not os.path.exists(audio_path):
+            pytest.skip("test.wav not found")
+
+        with pytest.raises(ValueError, match="Unsupported audio format"):
+            with open(audio_path, "rb") as f:
+                transcribe_audio(
+                    f,
+                    audio_format="invalid_format",
+                    engine="google"
+                )
+
+
+@pytest.mark.skipif(skip_transcription, reason="do not run speech transcription tests in CI")
+class TestAudioFormats:
+    """Tests for different audio formats."""
+
+    @pytest.mark.parametrize("filename,format", AUDIO_TEST_FILES)
+    def test_supported_formats(self, filename: str, format: str) -> None:
+        """Test that different audio formats work."""
+        audio_path = get_audio_file(filename)
+
+        if not os.path.exists(audio_path):
+            pytest.skip(f"Test file not found: {audio_path}")
+
+        # Just test that the format is accepted without errors
+        with open(audio_path, "rb") as f:
+            result = transcribe_audio(
+                f,
+                audio_format=format,
+                engine="google"
+            )
+
+        assert isinstance(result, str)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])