"""Source code for fresco.models.clc (case-level context model)."""

import math

import torch
from torch import nn
import torch.nn.functional as F


class CaseLevelContext(nn.Module):
    '''
    Case-level context model: multi-head self-attention over the document
    embeddings of a case, followed by one linear classification layer per task.

    Args:
        num_classes (sequence of int): number of classes for each task; one
            classification layer is created per entry.
        doc_embed_size (int): size of each input document embedding.
        att_dim_per_head (int): attention dimension of each head.
        att_heads (int): number of attention heads.
        att_dropout (float): dropout probability applied to the inputs, the
            attention weights, and the attention output.
        forward_mask (bool): if True, each document may only attend to itself
            and to earlier documents in the sequence.
        device (str): stored as ``self.device`` for backward compatibility;
            masks are built on the device of the input tensors.
    '''

    def __init__(self,
                 num_classes,
                 doc_embed_size=400,
                 att_dim_per_head=50,
                 att_heads=8,
                 att_dropout=0.1,
                 forward_mask=True,
                 device='cuda'
                 ):
        super().__init__()
        self.doc_embed_size = doc_embed_size
        self.att_dim_per_head = att_dim_per_head
        self.att_heads = att_heads
        self.att_dim_total = att_heads * att_dim_per_head
        self.att_dropout = att_dropout
        self.forward_mask = forward_mask
        self.num_tasks = len(num_classes)
        self.device = device

        # Q, K, V projections for self-attention
        self.q = self._xavier_linear(self.doc_embed_size, self.att_dim_total)
        self.k = self._xavier_linear(self.doc_embed_size, self.att_dim_total)
        self.v = self._xavier_linear(self.doc_embed_size, self.att_dim_total)

        self.input_drop = nn.Dropout(p=att_dropout)
        self.att_drop = nn.Dropout(p=att_dropout)
        self.output_drop = nn.Dropout(p=att_dropout)

        # prediction layers, one per task (created in num_classes order)
        self.classify_layers = nn.ModuleList(
            self._xavier_linear(self.att_dim_total, n) for n in num_classes
        )

    @staticmethod
    def _xavier_linear(in_dim, out_dim):
        '''Create an nn.Linear with Xavier-uniform weights and a zero bias.'''
        layer = nn.Linear(in_dim, out_dim)
        nn.init.xavier_uniform_(layer.weight)
        nn.init.zeros_(layer.bias)
        return layer

    def _split_heads(self, x):
        '''
        splits final dim of tensor into multiple heads for multihead attention

        parameters:
          - x: torch.tensor (float) [batch_size x seq_len x dim]

        outputs:
          - torch.tensor (float) [batch_size x att_heads x seq_len x att_dim_per_head]
            reshaped tensor for multihead attention
        '''
        batch_size = x.size(0)
        x = x.view(batch_size, -1, self.att_heads, self.att_dim_per_head)
        return torch.transpose(x, 1, 2)

    def _attention(self, q, k, v, drop, mask, max_seq_len):
        '''
        flexible attention operation for self and target attention

        parameters:
          - q: torch.tensor (float) [batch x heads x seq_len x dim1]
          - k: torch.tensor (float) [batch x heads x seq_len x dim1]
          - v: torch.tensor (float) [batch x heads x seq_len x dim2]
            NOTE: q and k must have the same dimension, but v can be different
          - drop: torch.nn.Dropout layer applied to the attention weights, or None
          - mask: torch.tensor (bool) [batch x seq_len], True for real
            (non-padded) entries, or None
          - max_seq_len: int, sequence length used to build the forward mask

        outputs:
          - torch.tensor (float) [batch x heads x seq_len x dim2]
        '''
        # batch x heads x seq_len x seq_len
        scores = torch.matmul(q, torch.transpose(k, -1, -2)) / math.sqrt(q.size(-1))
        # build masks where the scores live, not on a fixed constructor device
        device = scores.device

        # forward mask: query i may only attend to keys j <= i
        if self.forward_mask:
            positions = torch.arange(max_seq_len, device=device)
            allowed = positions[None, :] <= positions[:, None]
            blocked = torch.logical_not(allowed)[None, None]
            scores -= 1.e7 * blocked.float()

        if mask is not None:
            # [batch x 1 x 1 x seq_len]: masks out padded keys (columns) so the
            # softmax assigns them (effectively) no attention
            key_mask = mask[:, None, None, :]
            # [batch x 1 x seq_len x 1]: masks out padded queries (rows)
            query_mask = mask[:, None, :, None]
            scores -= 1.e7 * torch.logical_not(key_mask).float()

        # normalize attention matrix: batch x heads x seq_len x seq_len
        weights = F.softmax(scores, -1)

        # zero the rows of padded queries so they do not affect the output
        if mask is not None:
            weights = weights * query_mask.to(weights.dtype)

        # optional attention dropout
        if drop is not None:
            weights = drop(weights)

        # batch x heads x seq_len x dim2
        result = torch.matmul(weights, v)

        # ensure all padded entries of the output sequence are exactly 0
        if mask is not None:
            result = result * query_mask.to(result.dtype)
        return result

    def forward(self, doc_embeds, num_docs):
        '''
        Case level context forward pass.

        Args:
            doc_embeds (torch.tensor): Float tensor of shape
                [batch_size x max_seq_length x doc_embed_size].
                Document embeddings.
            num_docs (torch.tensor): Integer tensor of shape [batch_size].
                Number of reports per case; it is moved to the device of
                doc_embeds if necessary.

        Returns:
            list of torch.tensor: one float tensor of logits per task, each of
            shape [batch_size x max_seq_length x num_classes[i]]. Padded
            positions receive a zero attention output before classification.
        '''
        batch_size = num_docs.shape[0]
        max_seq_len = doc_embeds.shape[1]
        device = doc_embeds.device

        # True for real documents, False for padding: batch x max_seq_len
        positions = torch.arange(max_seq_len, device=device)
        mask = positions[None, :] < num_docs.to(device)[:, None]

        # self-attention
        doc_embeds = self.input_drop(doc_embeds)
        # each: batch x heads x max_seq_len x dim
        q = F.elu(self._split_heads(self.q(doc_embeds)))
        k = F.elu(self._split_heads(self.k(doc_embeds)))
        v = F.elu(self._split_heads(self.v(doc_embeds)))
        # batch x heads x max_seq_len x dim
        att_out = self._attention(q, k, v, self.att_drop, mask, max_seq_len)
        # batch x max_seq_len x heads*dim
        att_out = att_out.transpose(1, 2).reshape(
            batch_size, max_seq_len, self.att_dim_total)
        att_out = self.output_drop(att_out)

        # classify: one set of logits per task
        return [layer(att_out) for layer in self.classify_layers]