I am trying to compute the mean of the last 4 hidden layers of a BERT model.
Each hidden layer has the following dimensions:
outputs[1][-1]=[2,256,768] where 2 is batch size
outputs[1][-2]=[2,256,768] where 2 is batch size
outputs[1][-3]=[2,256,768] where 2 is batch size
outputs[1][-4]=[2,256,768] where 2 is batch size
I want to average the 4 layers, and the output should have the same dimension [2,256,768].
Here is my code:
    class BERT_CRF(nn.Module):
        def __init__(self, bert_model, num_labels):
            super(BERT_CRF, self).__init__()
            self.bert = bert_model
            self.dropout = nn.Dropout(0.25)
            self.classifier = nn.Linear(768, num_labels)
            self.crf = CRF(num_labels, batch_first=True)

        def forward(self, input_ids, attention_mask, labels=None, token_type_ids=None):
            outputs = self.bert(input_ids, attention_mask=attention_mask)
            sequence_output = torch.cat((outputs[1][-1], outputs[1][-2], outputs[1][-3], outputs[1][-4]),-1).mean(dim=[0,1,2])
            sequence_output = self.dropout(sequence_output)
            emission = self.classifier(sequence_output)
I tried sequence_output = torch.cat((outputs[1][-1], outputs[1][-2], outputs[1][-3], outputs[1][-4]),-1).mean(dim=[0,1,2])
but it does not give me the expected result.
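Your attempt concatenates the four tensors along the last dimension, which gives a [2,256,3072] tensor, and then averages over every dimension, which leaves a single scalar. What you want instead is to stack the four tensors along a new dimension and average over that new dimension only. Since you are looking at the last four elements of outputs[1], and outputs[1] is a sequence of tensors rather than a single tensor, you can do:
>>> torch.stack(outputs[1][-4:]).mean(0)
This returns the element-wise average of outputs[1][-1], outputs[1][-2], outputs[1][-3], and outputs[1][-4], with shape [2,256,768].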
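To make the shapes concrete, here is a minimal sketch that compares the concatenate-then-mean attempt with stacking. The name hidden_states is a hypothetical stand-in for outputs[1], filled with random tensors, and the count of 13 layers is only an assumption for illustration; no BERT model is loaded.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-in for outputs[1]: a tuple of per-layer hidden states,
# each of shape [batch, seq_len, hidden] = [2, 256, 768].
# The number of layers (13) is an assumption made for this sketch.
hidden_states = tuple(torch.randn(2, 256, 768) for _ in range(13))

# Attempt from the question: concatenating along the last dimension widens
# the hidden size to 4 * 768, and averaging over all dims leaves a scalar.
concatenated = torch.cat(hidden_states[-4:], dim=-1)
scalar = concatenated.mean(dim=[0, 1, 2])
print(concatenated.shape)  # torch.Size([2, 256, 3072])
print(scalar.shape)        # torch.Size([])

# Stacking creates a new leading dimension of size 4; averaging over that
# dimension alone keeps the per-layer shape.
stacked = torch.stack(hidden_states[-4:], dim=0)
sequence_output = stacked.mean(dim=0)
print(stacked.shape)          # torch.Size([4, 2, 256, 768])
print(sequence_output.shape)  # torch.Size([2, 256, 768])

# Sanity check against the explicit element-wise average of the four layers.
manual = (hidden_states[-1] + hidden_states[-2]
          + hidden_states[-3] + hidden_states[-4]) / 4
assert torch.allclose(sequence_output, manual, atol=1e-6)
```

In the forward method from the question, the same idea would replace the torch.cat line, so that sequence_output keeps the [2,256,768] shape that nn.Linear(768, num_labels) expects on its last dimension.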
...