import math
import torch
from torch import nn
import torch.nn.functional as F
class MTHiSAN(nn.Module):
    """Multitask hierarchical self-attention network for classifying cancer pathology reports.

    A document of word indices is cut into ``max_lines`` lines of
    ``max_words_per_line`` words. Word-level multihead self-attention followed
    by target attention turns each line into a line embedding; line-level
    self-attention followed by target attention turns the lines into a
    document embedding, which feeds one linear classifier per task.

    Word index 0 is padding: its embedding row is zeroed at initialization and
    index-0 positions are masked out of every attention and average.

    Args:
        embedding_matrix (array-like): 2-D array of word embeddings, one row
            per word index. It is NOT modified; a normalized copy is used.
        num_classes (list[int]): Number of possible output classes for each task.
        max_words_per_line (int): Number of words per line.
            Used to split documents into smaller chunks.
        max_lines (int): Maximum number of lines per document.
            Words beyond ``max_lines * max_words_per_line`` are ignored by the
            hierarchical branch.
        att_dim_per_head (int, default: 50): Dimension size of output from each
            attention head. Total output dimension is att_dim_per_head * att_heads.
        att_heads (int, default: 8): Number of attention heads for multihead attention.
        att_dropout (float, default: 0.1): Dropout rate for attention softmaxes
            and intermediate embeddings.
        bag_of_embeddings (bool, default: False): Adds a parallel bag of
            embeddings branch (mean of all non-padding word embeddings, then a
            tanh dense layer and dropout 0.5) whose output is concatenated to
            the final document embedding.
        embeddings_scale (float, default: 2.5): Each embedding column is
            centered and divided by ``std * embeddings_scale``.
    """

    def __init__(self,
                 embedding_matrix,
                 num_classes,
                 max_words_per_line,
                 max_lines,
                 att_dim_per_head=50,
                 att_heads=8,
                 att_dropout=0.1,
                 bag_of_embeddings=False,
                 embeddings_scale=2.5):
        super().__init__()
        self.max_words_per_line = max_words_per_line
        self.max_lines = max_lines
        self.max_len = max_lines * max_words_per_line
        self.att_dim_per_head = att_dim_per_head
        self.att_heads = att_heads
        self.att_dim_total = att_heads * att_dim_per_head

        # Normalize a float64 copy instead of the caller's array (in-place
        # normalization corrupted the caller's data and failed for int arrays).
        # Tensor.std defaults to the Bessel-corrected estimate (numpy ddof=1).
        weights = torch.as_tensor(embedding_matrix, dtype=torch.float64)
        weights = weights - weights.mean(dim=0)
        weights = weights / (weights.std(dim=0) * embeddings_scale)
        weights[0] = 0  # padding row
        embed_dim = weights.size(1)
        self.embedding = nn.Embedding.from_pretrained(weights.float(), freeze=False)
        self.word_embed_drop = nn.Dropout(p=att_dropout)

        # Layers are created in the original order so that seeded
        # initialization and state_dict keys stay unchanged.
        # Q, K, V for word-level self-attention
        self.word_q = self._xavier_linear(embed_dim, self.att_dim_total)
        self.word_k = self._xavier_linear(embed_dim, self.att_dim_total)
        self.word_v = self._xavier_linear(embed_dim, self.att_dim_total)
        self.word_att_drop = nn.Dropout(p=att_dropout)
        # learned target vector for word-level target attention
        self.word_target_drop = nn.Dropout(p=att_dropout)
        self.word_target = nn.Linear(1, self.att_dim_total, bias=False)
        torch.nn.init.uniform_(self.word_target.weight)
        self.line_embed_drop = nn.Dropout(p=att_dropout)

        # Q, K, V for line-level self-attention
        self.line_q = self._xavier_linear(self.att_dim_total, self.att_dim_total)
        self.line_k = self._xavier_linear(self.att_dim_total, self.att_dim_total)
        self.line_v = self._xavier_linear(self.att_dim_total, self.att_dim_total)
        # learned target vector for line-level target attention
        self.line_att_drop = nn.Dropout(p=att_dropout)
        self.line_target_drop = nn.Dropout(p=att_dropout)
        self.line_target = nn.Linear(1, self.att_dim_total, bias=False)
        torch.nn.init.uniform_(self.line_target.weight)
        self.doc_embed_drop = nn.Dropout(p=att_dropout)

        # optional bag of embeddings layers
        self.boe = bag_of_embeddings
        if self.boe:
            self.boe_dense = self._xavier_linear(embed_dim, embed_dim)
            self.boe_drop = nn.Dropout(p=0.5)

        # one dense classification layer per task
        in_size = self.att_dim_total + (embed_dim if self.boe else 0)
        self.classify_layers = nn.ModuleList()
        for n in num_classes:
            self.classify_layers.append(self._xavier_linear(in_size, n))

    @staticmethod
    def _xavier_linear(in_features, out_features):
        """Return an nn.Linear with Xavier-uniform weights and a zero bias."""
        layer = nn.Linear(in_features, out_features)
        torch.nn.init.xavier_uniform_(layer.weight)
        layer.bias.data.fill_(0.0)
        return layer

    @staticmethod
    def _used_length(mask):
        """Return 1 + the last index along the final dim that is True in any
        row of ``mask`` (0 if none is True).

        Slicing to this length removes only positions that are padding in
        every row, even when padding is interleaved with real tokens; for
        left-packed inputs it equals the maximum per-row True count.
        """
        used = mask.flatten(0, -2).any(0)
        positions = used.nonzero()
        return int(positions[-1]) + 1 if positions.numel() else 0

    def _split_heads(self, x):
        """Split the final dimension of ``x`` into attention heads.

        Args:
            x (torch.Tensor): Float tensor of shape [batch_size x seq_len x dim].

        Returns:
            torch.Tensor: Float tensor of shape
            [batch_size x att_heads x seq_len x att_dim_per_head].
        """
        # explicit seq_len (not -1) so zero-length sequences also reshape
        batch_size, seq_len = x.size(0), x.size(1)
        x = x.view(batch_size, seq_len, self.att_heads, self.att_dim_per_head)
        return x.transpose(1, 2)

    def _attention(self, q, k, v, drop=None, mask_q=None, mask_k=None, mask_v=None):
        """Scaled dot-product attention used for both self and target attention.

        NOTE: despite their names, ``mask_q`` masks KEY positions (columns of
        the score matrix) and ``mask_k`` masks QUERY rows; the names are kept
        for backward compatibility.

        Args:
            q (torch.Tensor): [batch x heads x len_q x dim1]; batch may be 1
                and broadcast (as for the learned target vectors).
            k (torch.Tensor): [batch x heads x len_k x dim1].
            v (torch.Tensor): [batch x heads x len_k x dim2].
                NOTE: q and k must have the same dimension, but v can differ.
            drop (torch.nn.Dropout, optional): Dropout on the attention weights.
            mask_q (torch.Tensor, optional): Bool [batch x len_k]; False key
                positions receive no attention.
            mask_k (torch.Tensor, optional): Bool [batch x len_q]; rows of the
                weight matrix for False queries are zeroed.
            mask_v (torch.Tensor, optional): Bool [batch x len_q]; rows of the
                output for False positions are zeroed.

        Returns:
            torch.Tensor: [batch x heads x len_q x dim2] attended values.
        """
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(q.size(-1))
        if mask_q is not None:
            key_padding = torch.logical_not(mask_q)[:, None, None, :]
            # Fill with the dtype's lowest finite value rather than subtracting
            # 1e7: 1e7 overflows float16 to -inf, and a fully padded row
            # (empty document) would then softmax to NaN.
            scores = scores.masked_fill(key_padding, torch.finfo(scores.dtype).min)
        weights = F.softmax(scores, -1)  # batch x heads x len_q x len_k
        if mask_k is not None:
            # zero rows of padded queries so they contribute nothing downstream
            weights = weights * mask_k[:, None, :, None].to(weights.dtype)
        if drop is not None:
            weights = drop(weights)
        result = torch.matmul(weights, v)  # batch x heads x len_q x dim2
        if mask_v is not None:
            # guarantee padded output positions are exactly 0
            result = result * mask_v[:, None, :, None].to(result.dtype)
        return result

    def _bag_of_embeddings(self, docs):
        """Return tanh(dense(mean non-padding word embedding)) with dropout,
        shape [batch x embed_dim], computed over every position of ``docs``.

        Documents without tokens get a zero mean (count clamped to 1) instead
        of NaN.
        """
        mask = docs != 0                                   # batch x seq_len
        used = self._used_length(mask)
        docs, mask = docs[:, :used], mask[:, :used]
        word_embeds = self.embedding(docs)                 # batch x used x embed
        word_embeds = word_embeds * mask.unsqueeze(-1).to(word_embeds.dtype)
        counts = mask.sum(-1, keepdim=True).clamp(min=1).to(word_embeds.dtype)
        bag_embeds = word_embeds.sum(1) / counts           # batch x embed
        bag_embeds = torch.tanh(self.boe_dense(bag_embeds))
        return self.boe_drop(bag_embeds)

    def forward(self, docs, return_embeds=False):
        """Classify a batch of tokenized documents.

        Args:
            docs (torch.Tensor): Integer tensor [batch x seq_len] of word
                indices, 0 = padding. The hierarchical branch truncates to
                ``max_lines * max_words_per_line`` words and zero-pads shorter
                inputs up to that length.
            return_embeds (bool, default: False): Also return document embeddings.

        Returns:
            list[torch.Tensor]: One [batch x num_classes[i]] logits tensor per
            task; if ``return_embeds`` is True, the tuple ``(logits, doc_embeds)``
            where doc_embeds is [batch x att_dim_total (+ embed_dim with
            bag_of_embeddings)].
        """
        batch_size = docs.size(0)
        # bag of embeddings runs first (on the untruncated docs) to keep the
        # original dropout RNG order
        if self.boe:
            bag_embeds = self._bag_of_embeddings(docs)

        # batch x max_lines x max_words_per_line; pad short inputs so reshape
        # cannot fail or fold tokens of neighbouring documents together
        docs = docs[:, :self.max_len]
        if docs.size(1) < self.max_len:
            docs = F.pad(docs, (0, self.max_len - docs.size(1)))
        docs = docs.reshape(batch_size, self.max_lines, self.max_words_per_line)

        # trim word positions / lines that are padding in every document
        mask_words = docs != 0                            # batch x max_lines x max_words
        mask_lines = mask_words.any(-1)                   # batch x max_lines
        max_words = self._used_length(mask_words)         # hereon 'words'
        max_lines = self._used_length(mask_lines)         # hereon 'lines'
        docs = docs[:, :max_lines, :max_words]            # batch x lines x words
        mask_words = mask_words[:, :max_lines, :max_words]
        mask_lines = mask_lines[:, :max_lines]            # batch x lines

        # fold lines into the batch dim; drop empty lines for speed and
        # scatter them back as zeros afterwards
        docs = docs.reshape(batch_size * max_lines, max_words)
        mask_words = mask_words.reshape(batch_size * max_lines, max_words)
        flat_mask_lines = mask_lines.reshape(batch_size * max_lines)
        docs = docs[flat_mask_lines]                      # filtered x words
        mask_words = mask_words[flat_mask_lines]          # filtered x words
        num_filtered = docs.size(0)

        # word embeddings and word self-attention
        word_embeds = self.word_embed_drop(self.embedding(docs))  # filtered x words x embed
        word_q = F.elu(self._split_heads(self.word_q(word_embeds)))
        word_k = F.elu(self._split_heads(self.word_k(word_embeds)))
        word_v = F.elu(self._split_heads(self.word_v(word_embeds)))
        word_att = self._attention(word_q, word_k, word_v, self.word_att_drop,
                                   mask_words, mask_words, mask_words)  # filtered x heads x words x dim

        # word target attention -> one embedding per line
        word_target = self.word_target(word_att.new_ones((1, 1))).view(
            1, self.att_heads, 1, self.att_dim_per_head)
        line_embeds = self._attention(word_target, word_att, word_att,
                                      self.word_target_drop, mask_words)  # filtered x heads x 1 x dim
        line_embeds = line_embeds.transpose(1, 2).reshape(num_filtered, self.att_dim_total)
        line_embeds = self.line_embed_drop(line_embeds)

        # restore the dropped empty lines as zero vectors
        line_embeds_full = line_embeds.new_zeros(batch_size * max_lines, self.att_dim_total)
        line_embeds_full[flat_mask_lines] = line_embeds
        line_embeds = line_embeds_full.reshape(batch_size, max_lines, self.att_dim_total)

        # line self-attention
        line_q = F.elu(self._split_heads(self.line_q(line_embeds)))
        line_k = F.elu(self._split_heads(self.line_k(line_embeds)))
        line_v = F.elu(self._split_heads(self.line_v(line_embeds)))
        line_att = self._attention(line_q, line_k, line_v, self.line_att_drop,
                                   mask_lines, mask_lines, mask_lines)  # batch x heads x lines x dim

        # line target attention -> one embedding per document
        line_target = self.line_target(line_att.new_ones((1, 1))).view(
            1, self.att_heads, 1, self.att_dim_per_head)
        doc_embeds = self._attention(line_target, line_att, line_att,
                                     self.line_target_drop, mask_lines)  # batch x heads x 1 x dim
        doc_embeds = doc_embeds.transpose(1, 2).reshape(batch_size, self.att_dim_total)
        doc_embeds = self.doc_embed_drop(doc_embeds)

        if self.boe:
            doc_embeds = torch.cat([doc_embeds, bag_embeds], 1)  # batch x heads*dim+embed

        logits = [layer(doc_embeds) for layer in self.classify_layers]
        if return_embeds:
            return logits, doc_embeds
        return logits