torchaudio.models¶
models 子包包含用于解决常见音频任务的模型定义。
ConvTasNet¶
-
class
torchaudio.models.ConvTasNet(num_sources: int = 2, enc_kernel_size: int = 16, enc_num_feats: int = 512, msk_kernel_size: int = 3, msk_num_feats: int = 128, msk_num_hidden_feats: int = 512, msk_num_layers: int = 8, msk_num_stacks: int = 3)[source]¶ Conv-TasNet:一个完全基于卷积的时域音频分离网络 1。
- Parameters
num_sources (int) – 分割的源的数量。
编码器/解码器卷积核大小 (int) – 编码器/解码器的卷积核大小,
。 enc_num_feats (int) – 传递给掩码生成器的特征维度,
。 msk_kernel_size (int) – 掩码生成器的卷积核大小,<P>。
msk_num_feats (int) – 掩码生成器中卷积块的输入/输出特征维度,<B, Sc>。
msk_num_hidden_feats (int) – 掩码生成器中卷积块的内部特征维度,<H>。
掩码层数 (int) – 掩码生成器中每个卷积块的层数,<X>。
msk_num_stacks (int) – mask生成器的卷积块数量,<R>。
注意
此实现对应于论文中的“非因果”设置。
-
forward(input: torch.Tensor) → torch.Tensor[source]¶ 执行源分离。生成音频源波形。
- Parameters
input (torch.Tensor) – 形状为 [batch, channel==1, frames] 的 3D 张量
- Returns
形状为 [batch, channel==num_sources, frames] 的三维张量
- Return type
DeepSpeech¶
-
class
torchaudio.models.DeepSpeech(n_feature: int, n_hidden: int = 2048, n_class: int = 40, dropout: float = 0.0)[source]¶ DeepSpeech 模型架构来自 2。
- Parameters
n_feature – 输入特征的数量
n_hidden – 内部隐藏单元大小。
n_class – 输出类别的数量
-
forward(x: torch.Tensor) → torch.Tensor[source]¶ - Parameters
x (torch.Tensor) – 维度为 (batch, channel, time, feature) 的张量。
- Returns
维度为 (batch, time, class) 的预测张量。
- Return type
张量
Wav2Letter¶
-
class
torchaudio.models.Wav2Letter(num_classes: int = 40, input_type: str = 'waveform', num_features: int = 1)[source]¶ Wav2Letter模型架构来自3。
\(\text{padding} = \frac{\text{ceil}(\text{kernel} - \text{stride})}{2}\)
- Parameters
-
forward(x: torch.Tensor) → torch.Tensor[source]¶ - Parameters
x (torch.Tensor) – 维度为 (batch_size, num_features, input_length) 的张量。
- Returns
维度为 (batch_size, number_of_classes, input_length) 的预测张量。
- Return type
张量
Wav2Vec2.0¶
Wav2Vec2Model¶
-
class
torchaudio.models.Wav2Vec2Model(feature_extractor: torch.nn.modules.module.Module, encoder: torch.nn.modules.module.Module)[source]¶ 编码器模型用于[4]。
注意
要构建模型,请使用其中一个工厂函数。
- Parameters
feature_extractor (torch.nn.Module) – 从原始音频 Tensor 中提取特征向量的特征提取器。
encoder (torch.nn.Module) – 编码器,将音频特征转换为标签上的概率分布序列(以负对数似然表示)。
-
extract_features(waveforms: torch.Tensor, lengths: Optional[torch.Tensor] = None) → Tuple[torch.Tensor, Optional[torch.Tensor]][source]¶ 从原始波形中提取特征向量
- Parameters
波形图 (张量) – 形状为
(batch, frames)的音频张量。长度 (张量,可选) – 表示批次中每个音频样本的有效长度。 形状:
(batch, )。
- Returns
- Feature vectors.
形状:
(batch, frames, feature dimention)- Tensor, optional:
指示批次中每个特征的有效长度,基于给定的
lengths参数计算。 形状:(batch, )。
- Return type
张量
-
forward(waveforms: torch.Tensor, lengths: Optional[torch.Tensor] = None) → Tuple[torch.Tensor, Optional[torch.Tensor]][source]¶ 计算标签上的概率分布序列。
- Parameters
波形图 (张量) – 形状为
(batch, frames)的音频张量。长度 (张量,可选) – 表示批次中每个音频样本的有效长度。 形状:
(batch, )。
- Returns
- The sequences of probability distribution (in logit) over labels.
形状:
(batch, frames, num labels)。- Tensor, optional:
指示批次中每个特征的有效长度,基于给定的
lengths参数计算。 形状:(batch, )。
- Return type
张量
工厂函数¶
wav2vec2_base¶
-
torchaudio.models.wav2vec2_base(num_out: int) → torchaudio.models.wav2vec2.model.Wav2Vec2Model[source]¶ 使用“Base”配置构建来自[4]的wav2vec2.0模型。
- Parameters
num_out – int 输出标签的数量。
- Returns
生成的模型。
- Return type
- Example - Reload fine-tuned model from Hugging Face:
>>> # Session 1 - Convert pretrained model from Hugging Face and save the parameters. >>> from torchaudio.models.wav2vec2.utils import import_huggingface_model >>> >>> original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") >>> model = import_huggingface_model(original) >>> torch.save(model.state_dict(), "wav2vec2-base-960h.pt") >>> >>> # Session 2 - Load model and the parameters >>> model = wav2vec2_base(num_out=32) >>> model.load_state_dict(torch.load("wav2vec2-base-960h.pt"))
wav2vec2_large¶
-
torchaudio.models.wav2vec2_large(num_out: int) → torchaudio.models.wav2vec2.model.Wav2Vec2Model[source]¶ 使用“Large”配置构建wav2vec2.0模型,参见 [4]。
- Parameters
num_out – int 输出标签的数量。
- Returns
生成的模型。
- Return type
- Example - Reload fine-tuned model from Hugging Face:
>>> # Session 1 - Convert pretrained model from Hugging Face and save the parameters. >>> from torchaudio.models.wav2vec2.utils import import_huggingface_model >>> >>> original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h") >>> model = import_huggingface_model(original) >>> torch.save(model.state_dict(), "wav2vec2-base-960h.pt") >>> >>> # Session 2 - Load model and the parameters >>> model = wav2vec2_large(num_out=32) >>> model.load_state_dict(torch.load("wav2vec2-base-960h.pt"))
wav2vec2_large_lv60k¶
-
torchaudio.models.wav2vec2_large_lv60k(num_out: int) → torchaudio.models.wav2vec2.model.Wav2Vec2Model[source]¶ 使用“Large LV-60k”配置构建wav2vec2.0模型 [4]。
- Parameters
num_out – int 输出标签的数量。
- Returns
生成的模型。
- Return type
- Example - Reload fine-tuned model from Hugging Face:
>>> # Session 1 - Convert pretrained model from Hugging Face and save the parameters. >>> from torchaudio.models.wav2vec2.utils import import_huggingface_model >>> >>> original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self") >>> model = import_huggingface_model(original) >>> torch.save(model.state_dict(), "wav2vec2-base-960h.pt") >>> >>> # Session 2 - Load model and the parameters >>> model = wav2vec2_large_lv60k(num_out=32) >>> model.load_state_dict(torch.load("wav2vec2-base-960h.pt"))
实用函数¶
import_huggingface_model¶
-
torchaudio.models.wav2vec2.utils.import_huggingface_model(original: torch.nn.modules.module.Module) → torchaudio.models.wav2vec2.model.Wav2Vec2Model[source]¶ 从Hugging Face的Transformers导入wav2vec2模型。
- Parameters
原始 (torch.nn.Module) –
Wav2Vec2ForCTC类的一个实例,来自transformers。- Returns
已导入模型。
- Return type
- Example
>>> from torchaudio.models.wav2vec2.utils import import_huggingface_model >>> >>> original = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") >>> model = import_huggingface_model(original) >>> >>> waveforms, _ = torchaudio.load("audio.wav") >>> logits, _ = model(waveforms)
import_fairseq_model¶
-
torchaudio.models.wav2vec2.utils.import_fairseq_model(original: torch.nn.modules.module.Module, num_out: Optional[int] = None) → torchaudio.models.wav2vec2.model.Wav2Vec2Model[source]¶ 从fairseq发布的预训练参数构建Wav2Vec2Model。
- Parameters
原始 (torch.nn.Module) – 一个fairseq的Wav2Vec2.0模型类的实例。 可以是
fairseq.models.wav2vec.wav2vec2_asr.Wav2VecEncoder或者fairseq.models.wav2vec.wav2vec2.Wav2Vec2Model。num_out (int, optional) – 输出标签的数量。仅当原始模型是
fairseq.models.wav2vec.wav2vec2.Wav2Vec2Model的实例时需要。
- Returns
已导入模型。
- Return type
- Example - Loading pretrain-only model
>>> from torchaudio.models.wav2vec2.utils import import_fairseq_model >>> >>> # Load model using fairseq >>> model_file = 'wav2vec_small.pt' >>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file]) >>> original = model[0] >>> imported = import_fairseq_model(original, num_out=28) >>> >>> # Perform feature extraction >>> waveform, _ = torchaudio.load('audio.wav') >>> features, _ = imported.extract_features(waveform) >>> >>> # Compare result with the original model from fairseq >>> reference = original.feature_extractor(waveform).transpose(1, 2) >>> torch.testing.assert_allclose(features, reference)
- Example - Fine-tuned model
>>> from torchaudio.models.wav2vec2.utils import import_fairseq_model >>> >>> # Load model using fairseq >>> model_file = 'wav2vec_small_960h.pt' >>> model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_file]) >>> original = model[0] >>> imported = import_fairseq_model(original.w2v_encoder) >>> >>> # Perform encoding >>> waveform, _ = torchaudio.load('audio.wav') >>> emission, _ = imported(waveform) >>> >>> # Compare result with the original model from fairseq >>> mask = torch.zeros_like(waveform) >>> reference = original(waveform, mask)['encoder_out'].transpose(0, 1) >>> torch.testing.assert_allclose(emission, reference)
WaveRNN¶
-
class
torchaudio.models.WaveRNN(upsample_scales: List[int], n_classes: int, hop_length: int, n_res_block: int = 10, n_rnn: int = 512, n_fc: int = 512, kernel_size: int = 5, n_freq: int = 128, n_hidden: int = 128, n_output: int = 128)[source]¶ 基于 fatchord 实现的 WaveRNN 模型。
最初的实现介绍在5。 波形和频谱图的输入通道必须为1。乘积 upsample_scales 必须等于 hop_length。
- Parameters
upsample_scales – 上采样比例列表。
n_classes – 输出类别的数量。
hop_length – 连续帧起始点之间的样本数。
n_res_block – 堆叠中 ResBlock 的数量。(默认值:
10)n_rnn – RNN 层的维度。(默认值:
512)n_fc – 全连接层的维度。(默认值:
512)kernel_size – 第一个 Conv1d 层中的卷积核大小数量。(默认值:
5)n_freq – 频谱图中的分箱数量。(默认值:
128)n_hidden – resblock 的隐藏层维度数量。(默认值:
128)n_output – melresnet 的输出维度数量。(默认值:
128)
- Example
>>> wavernn = WaveRNN(upsample_scales=[5,5,8], n_classes=512, hop_length=200) >>> waveform, sample_rate = torchaudio.load(file) >>> # waveform shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length) >>> specgram = MelSpectrogram(sample_rate)(waveform) # shape: (n_batch, n_channel, n_freq, n_time) >>> output = wavernn(waveform, specgram) >>> # output shape: (n_batch, n_channel, (n_time - kernel_size + 1) * hop_length, n_classes)
-
forward(waveform: torch.Tensor, specgram: torch.Tensor) → torch.Tensor[source]¶ 将输入通过 WaveRNN 模型。
- Parameters
waveform – WaveRNN 层的输入波形 (n_batch, 1, (n_time - kernel_size + 1) * hop_length)
specgram – 输入到 WaveRNN 层的频谱图 (n_batch, 1, n_freq, n_time)
- Returns
(批次大小, 1, (时间长度 - 内核大小 + 1) * 滑动步长, 类别数量)
- Return type
张量形状
参考文献¶
- 1
Yi Luo 和 Nima Mesgarani。Conv-tasnet:超越理想时频幅度掩码的语音分离。IEEE/ACM Transactions on Audio, Speech, and Language Processing,27(8):1256–1266,2019年8月。URL: http://dx.doi.org/10.1109/TASLP.2019.2915167,doi:10.1109/taslp.2019.2915167。
- 2
Awni Hannun, Carl Case, Jared Casper, Bryan Catanzaro, Greg Diamos, Erich Elsen, Ryan Prenger, Sanjeev Satheesh, Shubho Sengupta, Adam Coates, 和 Andrew Y. Ng。Deep speech: scaling up end-to-end speech recognition。2014。 arXiv:1412.5567。
- 3
Ronan Collobert, Christian Puhrsch, 和 Gabriel Synnaeve。Wav2letter:一个端到端的基于卷积网络的语音识别系统。2016。 arXiv:1609.03193。
- 4(1,2,3,4)
Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, 和 Michael Auli. Wav2vec 2.0: 一个用于语音表示自监督学习的框架。2020. arXiv:2006.11477.
- 5
Nal Kalchbrenner, Erich Elsen, Karen Simonyan, Seb Noury, Norman Casagrande, Edward Lockhart, Florian Stimberg, Aaron van den Oord, Sander Dieleman, 和 Koray Kavukcuoglu. 高效的神经音频合成。2018. arXiv:1802.08435.