2017-12-05 6 views
1

私は音楽生成のためのオートエンコーダーを開発しようとしています。その目的のために、私は音楽的関係を捉える損失関数を開発しようとしています。音楽エンコーディングのスパイラルロス機能

私の現在のアイデアは、システムが異なるオクターブで同じ音符を予測すると、音符が間違っている場合よりも損失が小さくなるはずであるという「螺旋」損失関数です。さらに、BやDからCのような正しい音符に近い音符も小さな損失を持つべきです。これは概念的には、コイルまたはスパイラル上の2点間の距離を求めることで、異なるオクターブ内の同じ音符がコイルに接する線に沿って位置するが、ループ距離によって分離されると考えることができる。

私はPyTorchで作業していますが、私の入力表現は36 x 36のTensorで、行は音符(MIDI範囲48:84、ピアノの中央の3オクターブ)を表し、列はタイムステップ(1列= 1/100秒)。行列の値は0または1のいずれかで、特定の時刻にメモがオンであることを示します。ここで

は、損失の私の現在の実装である:

def SpiralLoss(): 
    def spiral_loss(input, output): 
     loss = Variable(torch.FloatTensor([0])) 
     d = 5 
     r = 10 
     for i in xrange(input.size()[0]): 
      for j in xrange(input.size()[3]): 
       # take along the 1 axis because it's a column vector 
       inval, inind = torch.max(input[i, :, :, j], 1) 
       outval, outind = torch.max(output[i, :, :, j], 1) 
       note_loss = (r*30*(inind%12 - outind%12)).float() 
       octave_loss = (d*(inind/12 - outind/12)).float() 
       loss += torch.sqrt(torch.pow(note_loss, 2) + torch.pow(octave_loss, 2)) 
     return loss 
    return spiral_loss 

この損失の問題は、MAX関数が微分可能ではないということです。私はこの損失を差別化する方法を考えることができず、誰かがアイデアや提案を持っているのだろうかと疑問に思っていましたか?

これがこのような投稿の適切な場所であるかどうかはわかりません。そうでない場合は、より良い場所に向かうすべての点について本当に感謝しています。

ありがとうございます!

+0

入力と出力の4つの次元は何ですか? – McLawrence

+0

yup! NxCxHxW。私はMIDIピアノロール表現を使って作業しています。 Nはバッチの数であり、Cは畳み込み層での使用のためのチャネル(私の場合は1)である。 Hはミディノートのディメンション、Wは時間(ピアノロールの高さと幅) – bgenchel

+0

ああ。私はあなたが試してみることができ、それが動作するかどうか私に知らせる1つの可能な解決策を投稿 – McLawrence

答えて

1

ここで最大限にすることは、差別化のために問題があるだけでなく、出力を最大限に活用し、適切な場所にある場合は、間違った位置でわずかに低い値を罰することはありません。

一般的な考え方は、入力と修正された出力ベクトルの差に対して通常のL1またはL2損失を使用することです。出力には、オクターブを罰するいくつかのウェイトマスクを乗算し、

def create_mask(input_column): 
    r = 10 
    d = 5 
    mask = torch.FloatTensor(input_column.size()) 
    _, max_ind = torch.max(input_column, 0) 
    max_ind = int(max_ind[0]) 
    for i in range(mask.size(0)): 
     mask[i] = r*abs(i-max_ind)%12 + d*abs(i-max_ind)/12 
    return mask 

これはちょうど大まかに書かれたものであり、何か準備ができていませんが、理論的には仕事をするべきです。マスクベクトルはrequires_grad=Falseに設定する必要があります。これは、入力ごとに計算する正確な定数であるためです。したがって、入力には最大値を使用できますが、出力にはmaxを使用しないでください。

私はそれが助けてくれることを願っています!