import sys
import logging

import numpy as np
import torch
from torchaudio import functional as F
from transformers.pipelines.audio_utils import ffmpeg_read
from starlette.exceptions import HTTPException

logger = logging.getLogger(__name__)


def preprocess_inputs(inputs, sampling_rate):
    # decode the raw audio bytes with ffmpeg into a float32 numpy waveform
    inputs = ffmpeg_read(inputs, sampling_rate)

    if sampling_rate != 16000:
        # the diarization model operates on 16 kHz audio, so resample if needed
        inputs = F.resample(
            torch.from_numpy(inputs), sampling_rate, 16000
        ).numpy()

    if len(inputs.shape) != 1:
        logger.error(f"Diarization pipeline expects single channel audio, received {inputs.shape}")
        raise HTTPException(
            status_code=400,
            detail=f"Diarization pipeline expects single channel audio, received {inputs.shape}",
        )

    # the diarization model expects a float32 torch tensor of shape (channels, seq_len)
    diarizer_inputs = torch.from_numpy(inputs).float()
    diarizer_inputs = diarizer_inputs.unsqueeze(0)

    return inputs, diarizer_inputs
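

# For illustration: a 10 s mono clip decoded at 16 kHz comes back from
# preprocess_inputs as
#   inputs          -> np.ndarray of shape (160000,), dtype float32
#   diarizer_inputs -> torch.Tensor of shape (1, 160000), dtype float32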


def diarize_audio(diarizer_inputs, diarization_pipeline, parameters):
    # preprocess_inputs always resamples the waveform to 16 kHz, so report
    # that rate here rather than the original parameters.sampling_rate
    diarization = diarization_pipeline(
        {"waveform": diarizer_inputs, "sample_rate": 16000},
        num_speakers=parameters.num_speakers,
        min_speakers=parameters.min_speakers,
        max_speakers=parameters.max_speakers,
    )

    segments = []
    for segment, track, label in diarization.itertracks(yield_label=True):
        segments.append(
            {
                "segment": {"start": segment.start, "end": segment.end},
                "track": track,
                "label": label,
            }
        )

    # the diarizer may emit several consecutive segments for the same speaker;
    # merge them so that each entry covers one full speaker turn
    new_segments = []
    prev_segment = cur_segment = segments[0]

    for i in range(1, len(segments)):
        cur_segment = segments[i]

        # check whether the speaker has changed
        if cur_segment["label"] != prev_segment["label"]:
            # close the previous speaker's turn where the new speaker starts
            new_segments.append(
                {
                    "segment": {
                        "start": prev_segment["segment"]["start"],
                        "end": cur_segment["segment"]["start"],
                    },
                    "speaker": prev_segment["label"],
                }
            )
            prev_segment = segments[i]

    # add the final turn, which runs to the end of the last diarized segment
    new_segments.append(
        {
            "segment": {
                "start": prev_segment["segment"]["start"],
                "end": cur_segment["segment"]["end"],
            },
            "speaker": prev_segment["label"],
        }
    )

    return new_segments
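

# For illustration, the merging step above collapses raw diarizer output like
#   [{"segment": {"start": 0.0, "end": 1.0}, "track": "A", "label": "SPEAKER_00"},
#    {"segment": {"start": 1.0, "end": 1.5}, "track": "B", "label": "SPEAKER_00"},
#    {"segment": {"start": 1.5, "end": 3.0}, "track": "C", "label": "SPEAKER_01"}]
# into one entry per speaker turn:
#   [{"segment": {"start": 0.0, "end": 1.5}, "speaker": "SPEAKER_00"},
#    {"segment": {"start": 1.5, "end": 3.0}, "speaker": "SPEAKER_01"}]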


def post_process_segments_and_transcripts(new_segments, transcript, group_by_speaker) -> list:
    # Whisper can leave the final chunk's end timestamp as None, so substitute
    # a large sentinel value to keep the argmin below well-defined
    end_timestamps = np.array(
        [
            chunk["timestamp"][-1] if chunk["timestamp"][-1] is not None else sys.float_info.max
            for chunk in transcript
        ]
    )
    segmented_preds = []

    # align the diarizer timestamps with the ASR timestamps
    for segment in new_segments:
        # get the diarizer end timestamp for this speaker turn
        end_time = segment["segment"]["end"]
        # find the ASR chunk whose end timestamp is closest to it
        upto_idx = np.argmin(np.abs(end_timestamps - end_time))

        if group_by_speaker:
            segmented_preds.append(
                {
                    "speaker": segment["speaker"],
                    "text": "".join(
                        [chunk["text"] for chunk in transcript[: upto_idx + 1]]
                    ),
                    "timestamp": (
                        transcript[0]["timestamp"][0],
                        transcript[upto_idx]["timestamp"][1],
                    ),
                }
            )
        else:
            for i in range(upto_idx + 1):
                segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})

        # crop the transcript and timestamp lists so the next argmin only
        # searches the chunks that have not been attributed yet
        transcript = transcript[upto_idx + 1:]
        end_timestamps = end_timestamps[upto_idx + 1:]

        if len(end_timestamps) == 0:
            break

    return segmented_preds
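

# For illustration, the alignment step above: with ASR chunks ending at
# [2.1, 4.0, 6.5] and a speaker turn ending at 4.2,
# np.argmin(np.abs(end_timestamps - 4.2)) picks index 1 (the chunk ending at
# 4.0), so chunks 0 and 1 are attributed to that speaker and the remaining
# chunk is carried over to the next turn.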


def diarize(diarization_pipeline, file, parameters, asr_outputs):
    _, diarizer_inputs = preprocess_inputs(file, parameters.sampling_rate)

    segments = diarize_audio(
        diarizer_inputs,
        diarization_pipeline,
        parameters,
    )

    # group_by_speaker=False keeps one entry per ASR chunk, each tagged with
    # its speaker, rather than concatenating whole speaker turns into one string
    return post_process_segments_and_transcripts(
        segments, asr_outputs["chunks"], group_by_speaker=False
    )
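

# --- Usage sketch ---
# A minimal, illustrative way to wire these helpers together. The
# DiarizationParams dataclass below is a stand-in for whatever request/parameter
# object the serving layer actually passes in (only the attributes read above
# are assumed); the model names and the HF_TOKEN environment variable are
# likewise assumptions, not part of this module.
if __name__ == "__main__":
    import os
    from dataclasses import dataclass
    from typing import Optional

    from pyannote.audio import Pipeline
    from transformers import pipeline

    @dataclass
    class DiarizationParams:
        sampling_rate: int = 16000
        num_speakers: Optional[int] = None
        min_speakers: Optional[int] = None
        max_speakers: Optional[int] = None

    # gated model: requires accepting its terms on the Hub and an access token
    diarization_pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        use_auth_token=os.environ.get("HF_TOKEN"),
    )
    asr = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-small",
        return_timestamps=True,
    )

    with open("audio.wav", "rb") as f:
        audio_bytes = f.read()

    # chunk-level timestamps come back under asr_outputs["chunks"]
    asr_outputs = asr(audio_bytes)
    for entry in diarize(diarization_pipeline, audio_bytes, DiarizationParams(), asr_outputs):
        print(entry["speaker"], entry["text"])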