Source code for deepke.name_entity_re.few_shot.models.model

import torch
from torch import nn
from torch.nn import functional as F
from transformers.configuration_bart import BartConfig
from .modeling_bart import BartModel, _prepare_bart_decoder_inputs

from ..utils import avg_token_embeddings, seq_to_mask, get_model_device
from functools import partial
from typing import Union


class PromptBartEncoder(nn.Module):
    def __init__(self, encoder):
        super(PromptBartEncoder, self).__init__()
        self.bart_encoder = encoder
    def forward(self, src_tokens, attention_mask=None, past_key_values=None):
        encoder_dicts = self.bart_encoder(input_ids=src_tokens, attention_mask=attention_mask,
                                          past_key_values=past_key_values, return_dict=True,
                                          output_hidden_states=True)
        return encoder_dicts.last_hidden_state, encoder_dicts.hidden_states
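
# A minimal sketch (illustration only, not part of DeepKE): PromptBartEncoder just
# forwards to the wrapped encoder and unpacks its dict-style output. `_DummyEncoder`
# below is a hypothetical stand-in for the modified BART encoder from .modeling_bart.
def _demo_prompt_encoder():
    from types import SimpleNamespace

    class _DummyEncoder(nn.Module):
        def forward(self, input_ids, attention_mask=None, past_key_values=None,
                    return_dict=True, output_hidden_states=True):
            hidden = torch.zeros(input_ids.size(0), input_ids.size(1), 4)
            return SimpleNamespace(last_hidden_state=hidden, hidden_states=(hidden,))

    encoder = PromptBartEncoder(_DummyEncoder())
    last_hidden, all_hidden = encoder(torch.ones(2, 5, dtype=torch.long))
    assert last_hidden.shape == (2, 5, 4) and len(all_hidden) == 1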
class PromptBartDecoder(nn.Module):
    def __init__(self, decoder, pad_token_id, label_ids, use_prompt=False, prompt_len=10, learn_weights=False):
        super(PromptBartDecoder, self).__init__()
        self.bart_decoder = decoder
        self.pad_token_id = pad_token_id
        self.use_prompt = use_prompt
        self.prompt_len = prompt_len
        self.learn_weights = learn_weights
        self.label_ids = label_ids
        print(label_ids)

        if self.learn_weights:  # set learnable average weights
            self.averge_weights = nn.ParameterList(parameters=None)
            for id in label_ids:
                if len(id) > 1:
                    self.averge_weights.append(nn.Parameter(torch.FloatTensor(len(id)).uniform_(1.0, 2.5)))
            print(self.averge_weights)
            mapping = [0, 2]
            for id in label_ids:
                mapping += id[:1]
            mapping = torch.LongTensor(mapping)
        else:
            mapping = torch.LongTensor([0, 2] + label_ids)
        self.label_start_id = min(label_ids)
        self.label_end_id = max(label_ids) + 1
        self.register_buffer('mapping', mapping)
        self.src_start_index = len(mapping)

        hidden_size = decoder.embed_tokens.weight.size(1)
        self.bart_mlp = nn.Sequential(nn.Linear(hidden_size, hidden_size),
                                      nn.Dropout(0.3),
                                      nn.ReLU(),
                                      nn.Linear(hidden_size, hidden_size))
        self.dropout_layer = nn.Dropout(0.3)
    def forward(self, tgt_tokens, prompt_state):
        cumsum = tgt_tokens.eq(1).flip(dims=[1]).cumsum(dim=-1)
        tgt_pad_mask = cumsum.flip(dims=[-1]).ne(cumsum[:, -1:])

        encoder_outputs = prompt_state.encoder_output  # last_hidden_state
        attention_mask = prompt_state.encoder_mask     # attention_mask
        first = prompt_state.first
        src_tokens = prompt_state.src_tokens
        past_key_values = prompt_state.past_key_values

        # map target tokens: ids below src_start_index are special/tag tokens, the rest point into the source
        mapping_token_mask = tgt_tokens.lt(self.src_start_index)
        mapped_tokens = tgt_tokens.masked_fill(tgt_tokens.ge(self.src_start_index), 0)
        tag_mapped_tokens = self.mapping[mapped_tokens]

        src_tokens_index = tgt_tokens - self.src_start_index  # bsz x num_src_token
        src_tokens_index = src_tokens_index.masked_fill(src_tokens_index.lt(0), 0)
        if first is not None:
            src_tokens = src_tokens.gather(index=first, dim=1)
        word_mapped_tokens = src_tokens.gather(index=src_tokens_index, dim=1)

        tokens = torch.where(mapping_token_mask, tag_mapped_tokens, word_mapped_tokens)  # bsz x max_len
        tokens = tokens.masked_fill(tgt_pad_mask, self.pad_token_id)

        decoder_input_ids, _, causal_mask = _prepare_bart_decoder_inputs(
            self.pad_token_id, tokens,
            decoder_input_ids=None,
            decoder_padding_mask=None,
            causal_mask_dtype=self.bart_decoder.embed_tokens.weight.dtype
        )
        if self.use_prompt:
            assert past_key_values is not None
            _, _, seqlen, _ = past_key_values[0]['self']['prev_value'].shape
            tgt_len = decoder_input_ids.size(1)
            temp_mask = torch.zeros(tgt_len, seqlen).to(causal_mask.device)  # tgt_len x preseqlen
            causal_mask = torch.cat([temp_mask, causal_mask], dim=1)  # tgt_len x (preseqlen + tgt_len)

        if self.training:
            tokens = tokens[:, :-1]
            decoder_pad_mask = tokens.eq(self.pad_token_id)
            decoder_output = self.bart_decoder(input_ids=tokens,
                                               encoder_hidden_states=encoder_outputs,  # last_hidden_state
                                               encoder_padding_mask=attention_mask,    # attention_mask
                                               decoder_padding_mask=decoder_pad_mask,
                                               decoder_causal_mask=causal_mask[:tokens.size(1), :self.prompt_len + tokens.size(1)],
                                               output_hidden_states=True,
                                               past_key_values=past_key_values,
                                               return_dict=True)
        else:
            past_key_values = prompt_state.past_key_values
            decoder_output = self.bart_decoder(input_ids=tokens,
                                               encoder_hidden_states=encoder_outputs,
                                               encoder_padding_mask=attention_mask,
                                               decoder_padding_mask=None,
                                               decoder_causal_mask=None,
                                               past_key_values=past_key_values,
                                               use_cache=True,
                                               return_dict=True)
        hidden_state = decoder_output.last_hidden_state  # bsz x max_len x hidden_size
        hidden_state = self.dropout_layer(hidden_state)
        if not self.training:
            prompt_state.past_key_values = decoder_output.past_key_values

        logits = hidden_state.new_full((hidden_state.size(0), hidden_state.size(1),
                                        self.src_start_index + src_tokens.size(-1)),
                                       fill_value=-1e24)

        # compute eos scores
        eos_scores = F.linear(hidden_state,
                              self.dropout_layer(self.bart_decoder.embed_tokens.weight[2:3]))  # bsz x max_len x 1
        if self.learn_weights:  # use averge_weights to compute entity-label scores
            tag_scores = None
            idx = 0
            for ids in self.label_ids:  # bsz x max_len x num_class
                if len(ids) <= 1:
                    temp_score = F.linear(hidden_state,
                                          self.dropout_layer(self.bart_decoder.embed_tokens.weight[ids]))
                else:
                    weight = F.softmax(self.averge_weights[idx], dim=0)
                    temp_score = F.linear(hidden_state,
                                          self.dropout_layer(self.bart_decoder.embed_tokens.weight[[ids[0]]])) * weight[0]
                    for i in range(1, len(ids)):
                        temp_score = temp_score + F.linear(hidden_state,
                                                           self.dropout_layer(self.bart_decoder.embed_tokens.weight[[ids[i]]])) * weight[i]
                    idx += 1
                if tag_scores is None:
                    tag_scores = temp_score
                else:
                    tag_scores = torch.cat((tag_scores, temp_score), dim=2)
        else:
            tag_scores = F.linear(hidden_state,
                                  self.dropout_layer(self.bart_decoder.embed_tokens.weight[self.label_start_id:self.label_end_id]))  # bsz x max_len x num_class

        # bsz x max_bpe_len x hidden_size
        src_outputs = encoder_outputs
        if hasattr(self, 'encoder_mlp'):  # note: __init__ defines `bart_mlp`, so this branch never fires here
            src_outputs = self.encoder_mlp(src_outputs)

        if first is not None:
            mask = first.eq(0)  # bsz x 1 x max_word_len
            # bsz x max_word_len x hidden_size
            src_outputs = src_outputs.gather(index=first.unsqueeze(2).repeat(1, 1, src_outputs.size(-1)), dim=1)
        else:
            mask = attention_mask.eq(0)
            # src_outputs = self.decoder.embed_tokens(src_tokens)
        mask = mask.unsqueeze(1)
        input_embed = self.dropout_layer(self.bart_decoder.embed_tokens(src_tokens))  # bsz x max_word_len x hidden_size
        src_outputs = (src_outputs + input_embed) / 2
        word_scores = torch.einsum('blh,bnh->bln', hidden_state, src_outputs)  # bsz x max_len x max_word_len
        mask = mask.__or__(src_tokens.eq(2).cumsum(dim=1).ge(1).unsqueeze(1))
        word_scores = word_scores.masked_fill(mask, -1e32)

        logits[:, :, 1:2] = eos_scores
        logits[:, :, 2:self.src_start_index] = tag_scores
        logits[:, :, self.src_start_index:] = word_scores

        return logits, prompt_state
    def decode(self, tokens, state):
        return self(tokens, state)[0][:, -1]
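
# A worked example (illustration only) of the pointer-style target vocabulary used
# by PromptBartDecoder.forward: ids below src_start_index select special/tag tokens
# through `mapping`, while larger ids point back into the source sentence. All token
# ids here are hypothetical, not taken from a real tokenizer.
def _demo_target_mapping():
    mapping = torch.LongTensor([0, 2, 50265, 50266])  # specials + two tag-token ids
    src_start_index = len(mapping)                    # = 4
    src_tokens = torch.LongTensor([[101, 7592, 2088, 102]])
    tgt_tokens = torch.LongTensor([[0, 5, 2, 1]])     # 5 -> source position 5 - 4 = 1
    mapping_token_mask = tgt_tokens.lt(src_start_index)
    tag_mapped = mapping[tgt_tokens.masked_fill(tgt_tokens.ge(src_start_index), 0)]
    diff = tgt_tokens - src_start_index
    src_index = diff.masked_fill(diff.lt(0), 0)
    word_mapped = src_tokens.gather(index=src_index, dim=1)
    tokens = torch.where(mapping_token_mask, tag_mapped, word_mapped)
    assert tokens.tolist() == [[0, 7592, 50265, 2]]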
class PromptBartModel(nn.Module):
    def __init__(self, tokenizer, label_ids, args):
        super(PromptBartModel, self).__init__()
        self.use_prompt = args.use_prompt
        self.prompt_len = args.prompt_len
        self.prompt_dim = args.prompt_dim
        self.learn_weights = args.learn_weights
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        bart_name = args.bart_name
        self.bart_config = BartConfig.from_pretrained(bart_name)
        self.bart_config.use_prompt = args.use_prompt
        self.bart_config.preseqlen = args.prompt_len
        bart_config = self.bart_config
        bart_model = BartModel.from_pretrained(bart_name, config=bart_config)
        num_tokens, _ = bart_model.encoder.embed_tokens.weight.shape
        bart_model.resize_token_embeddings(len(tokenizer.unique_no_split_tokens) + num_tokens)
        bart_model = avg_token_embeddings(tokenizer, bart_model, bart_name, num_tokens)

        self.prompt_encoder = PromptBartEncoder(bart_model.encoder)
        self.prompt_decoder = PromptBartDecoder(bart_model.decoder, tokenizer.pad_token_id, label_ids,
                                                self.use_prompt, self.prompt_len, self.learn_weights)

        self.prompt_inputs = torch.arange(self.prompt_len).long()
        self.encoder_prompt_embed = nn.Embedding(self.prompt_len, bart_config.d_model)
        self.encoder_mlp = nn.Sequential(
            nn.Linear(bart_config.d_model, self.prompt_dim),
            nn.Tanh(),
            nn.Linear(self.prompt_dim, bart_config.decoder_layers * 2 * bart_config.d_model))
        self.decoder_prompt_embed = nn.Embedding(self.prompt_len, bart_config.d_model)
        self.decoder_mlp = nn.Sequential(
            nn.Linear(bart_config.d_model, self.prompt_dim),
            nn.Tanh(),
            nn.Linear(self.prompt_dim, bart_config.decoder_layers * 2 * bart_config.d_model))
        self.prompt_cross_embed = nn.Embedding(self.prompt_len, bart_config.d_model)
        self.cross_mlp = nn.Sequential(
            nn.Linear(bart_config.d_model, self.prompt_dim),
            nn.Tanh(),
            nn.Linear(self.prompt_dim, bart_config.decoder_layers * 2 * bart_config.d_model))
        self.dropout = nn.Dropout(0.0)
    def forward(self, src_tokens, tgt_tokens, src_seq_len, first):
        prompt_state = self.generator(src_tokens, src_seq_len, first)
        decoder_outputs, prompt_state = self.prompt_decoder(tgt_tokens, prompt_state)
        return decoder_outputs
    def generator(self, src_tokens, src_seq_len, first):
        batch_size = src_tokens.size(0)
        past_key_values = self.get_prompt(batch_size) if self.use_prompt else None
        attention_mask = seq_to_mask(src_seq_len, max_len=src_tokens.size(1))
        encoder_outputs, hidden_states = self.prompt_encoder(src_tokens, attention_mask=attention_mask,
                                                             past_key_values=past_key_values)
        prompt_state = PromptBartState(encoder_outputs, attention_mask, past_key_values, src_tokens,
                                       first, hidden_states[0], self.bart_config.preseqlen)
        return prompt_state
    def get_prompt(self, batch_size):
        input_tokens = self.prompt_inputs.unsqueeze(0).expand(batch_size, -1).to(self.device)

        # encoder prompt
        encoder_embed = self.encoder_prompt_embed(input_tokens)
        past_key_values = self.encoder_mlp(encoder_embed)  # bsz, seqlen, layer*emb
        bsz, seqlen, _ = past_key_values.shape
        past_key_values = past_key_values.view(bsz, seqlen, self.bart_config.decoder_layers * 2,
                                               self.bart_config.decoder_attention_heads,
                                               self.bart_config.d_model // self.bart_config.decoder_attention_heads)
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)  # key + value

        # decoder prompt
        decoder_embed = self.decoder_prompt_embed(input_tokens)
        past_key_values2 = self.decoder_mlp(decoder_embed)  # bsz, seqlen, layer*emb
        past_key_values2 = past_key_values2.view(bsz, seqlen, self.bart_config.decoder_layers * 2,
                                                 self.bart_config.decoder_attention_heads,
                                                 self.bart_config.d_model // self.bart_config.decoder_attention_heads)
        past_key_values2 = self.dropout(past_key_values2)
        past_key_values2 = past_key_values2.permute([2, 0, 3, 1, 4]).split(2)

        # cross prompt
        cross_embed = self.prompt_cross_embed(input_tokens)
        past_key_values_enc = self.cross_mlp(cross_embed)  # bsz, seqlen, layer*emb
        past_key_values_enc = past_key_values_enc.view(bsz, seqlen, self.bart_config.decoder_layers * 2,
                                                       self.bart_config.decoder_attention_heads,
                                                       self.bart_config.d_model // self.bart_config.decoder_attention_heads)
        past_key_values_enc = self.dropout(past_key_values_enc)
        past_key_values_enc = past_key_values_enc.permute([2, 0, 3, 1, 4]).split(2)

        result = []
        for i, key_val in enumerate(past_key_values):
            temp_dict = {'self': {"prev_key": key_val[0].contiguous(),
                                  "prev_value": key_val[1].contiguous(),
                                  "prev_key_padding_mask": torch.zeros(bsz, seqlen).to(key_val.device).bool()  # bsz, preseqlen
                                  },
                         }
            key_val2 = past_key_values2[i]
            temp_dict['encoder_decoder'] = {"prev_key": key_val2[0].contiguous(),
                                            "prev_value": key_val2[1].contiguous(),
                                            "prev_key_padding_mask": torch.zeros(bsz, seqlen).to(key_val2.device).bool()
                                            }
            key_val_enc = past_key_values_enc[i]
            temp_dict['encoder'] = {"prev_key": key_val_enc[0].contiguous(),
                                    "prev_value": key_val_enc[1].contiguous(),
                                    "prev_key_padding_mask": torch.zeros(bsz, seqlen).to(key_val_enc.device).bool()
                                    }
            result.append(temp_dict)
        return result
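
# A shape walk-through (illustration only) of the key/value reshaping in get_prompt,
# with hypothetical sizes: each layer's prev_key/prev_value must come out as
# (bsz, heads, prompt_len, head_dim) so the attention modules can prepend them.
def _demo_prompt_shapes():
    bsz, prompt_len, layers, heads, d_model = 2, 3, 2, 4, 16
    flat = torch.randn(bsz, prompt_len, layers * 2 * d_model)
    kv = flat.view(bsz, prompt_len, layers * 2, heads, d_model // heads)
    kv = kv.permute([2, 0, 3, 1, 4]).split(2)  # one (key, value) pair per layer
    assert len(kv) == layers
    assert kv[0][0].shape == (bsz, heads, prompt_len, d_model // heads)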
class PromptBartState(object):
    def __init__(self, encoder_output, encoder_mask, past_key_values, src_tokens, first, src_embed_outputs, preseqlen):
        self.encoder_output = encoder_output
        self.encoder_mask = encoder_mask
        self.past_key_values = past_key_values
        self.src_tokens = src_tokens
        self.first = first
        self.src_embed_outputs = src_embed_outputs
        self.preseqlen = preseqlen

    def _reorder_state(self, state: Union[torch.Tensor, list, tuple], indices: torch.LongTensor, dim: int = 0):
        if isinstance(state, torch.Tensor):
            state = state.index_select(index=indices, dim=dim)
        elif isinstance(state, list):
            for i in range(len(state)):
                assert state[i] is not None
                state[i] = self._reorder_state(state[i], indices, dim)
        elif isinstance(state, tuple):
            tmp_list = []
            for i in range(len(state)):
                assert state[i] is not None
                tmp_list.append(self._reorder_state(state[i], indices, dim))
            state = tuple(tmp_list)
        else:
            raise TypeError(f"Cannot reorder data of type: {type(state)}")
        return state
    def reorder_state(self, indices: torch.LongTensor):
        # This class inherits from `object`, which has no reorder_state to delegate
        # to, so the encoder tensors are reordered here directly.
        if self.encoder_output is not None:
            self.encoder_output = self._reorder_state(self.encoder_output, indices)
        if self.encoder_mask is not None:
            self.encoder_mask = self._reorder_state(self.encoder_mask, indices)
        self.src_tokens = self._reorder_state(self.src_tokens, indices)
        if self.first is not None:
            self.first = self._reorder_state(self.first, indices)
        self.src_embed_outputs = self._reorder_state(self.src_embed_outputs, indices)
        if self.past_key_values is not None:
            new = []
            for layer in self.past_key_values:
                new_layer = {}
                for key1 in list(layer.keys()):
                    new_layer_ = {}
                    for key2 in list(layer[key1].keys()):
                        if layer[key1][key2] is not None:
                            layer[key1][key2] = self._reorder_state(layer[key1][key2], indices)
                        new_layer_[key2] = layer[key1][key2]
                    new_layer[key1] = new_layer_
                new.append(new_layer)
            self.past_key_values = new
    def num_samples(self):
        if self.encoder_output is not None:
            return self.encoder_output.size(0)
        else:
            return None
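
# A small demo (illustration only) of the indices beam search passes to
# reorder_state: repeat_interleave duplicates every batch element once per beam,
# keeping batch-dependent tensors aligned after the batch grows to bsz * num_beams.
def _demo_reorder_indices():
    batch_size, num_beams = 2, 3
    indices = torch.arange(batch_size).repeat_interleave(num_beams)
    assert indices.tolist() == [0, 0, 0, 1, 1, 1]
    encoder_output = torch.randn(batch_size, 5, 8)
    assert encoder_output.index_select(0, indices).size(0) == batch_size * num_beams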
class PromptGeneratorModel(nn.Module):
    def __init__(self, prompt_model, max_length=20, max_len_a=0.0, num_beams=1,
                 do_sample=False, bos_token_id=None, eos_token_id=None,
                 repetition_penalty=1, length_penalty=1.0, pad_token_id=0,
                 restricter=None):
        super(PromptGeneratorModel, self).__init__()
        self.prompt_model = prompt_model
        self.decoder = prompt_model.prompt_decoder

        self.generate_func = partial(greedy_generate, decoder=self.decoder, max_length=max_length,
                                     max_len_a=max_len_a,
                                     num_beams=num_beams,
                                     bos_token_id=bos_token_id, eos_token_id=eos_token_id,
                                     repetition_penalty=repetition_penalty,
                                     length_penalty=length_penalty, pad_token_id=pad_token_id,
                                     restricter=restricter)
        self.do_sample = do_sample
        self.max_length = max_length
        self.num_beams = num_beams
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.repetition_penalty = repetition_penalty
        self.length_penalty = length_penalty
        self.pad_token_id = pad_token_id
        self.restricter = restricter
        self.max_len_a = max_len_a
    def forward(self, src_tokens, tgt_tokens, src_seq_len=None, tgt_seq_len=None, first=None):
        """
        :param torch.LongTensor src_tokens: bsz x max_len
        :param torch.LongTensor tgt_tokens: bsz x max_len'
        :param torch.LongTensor src_seq_len: bsz
        :param torch.LongTensor tgt_seq_len: bsz
        :return: decoder output logits from the underlying PromptBartModel
        """
        return self.prompt_model(src_tokens, tgt_tokens, src_seq_len, first)
    def predict(self, src_tokens, src_seq_len=None, first=None):
        """
        :param torch.LongTensor src_tokens: bsz x max_len
        :param torch.LongTensor src_seq_len: bsz
        :return: generated token ids
        """
        prompt_state = self.prompt_model.generator(src_tokens, src_seq_len, first)  # encoder output
        result = self.generate_func(tokens=None, state=prompt_state)
        return result
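
# A minimal usage sketch (assumptions: `model` is a trained PromptBartModel and the
# token ids/lengths below are hypothetical). forward produces training logits, while
# predict encodes once and then decodes autoregressively through greedy_generate.
def _demo_predict(model):
    generator = PromptGeneratorModel(model, max_length=10, num_beams=1,
                                     bos_token_id=0, eos_token_id=1, pad_token_id=1)
    src_tokens = torch.LongTensor([[0, 8, 9, 2]])
    src_seq_len = torch.LongTensor([4])
    return generator.predict(src_tokens, src_seq_len=src_seq_len, first=None)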
@torch.no_grad()
def greedy_generate(decoder, tokens=None, state=None, max_length=20, max_len_a=0.0, num_beams=1,
                    bos_token_id=None, eos_token_id=None, pad_token_id=0,
                    repetition_penalty=1, length_penalty=1.0, restricter=None):
    if num_beams == 1:
        token_ids = _no_beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length,
                                             max_len_a=max_len_a,
                                             bos_token_id=bos_token_id, eos_token_id=eos_token_id,
                                             repetition_penalty=repetition_penalty,
                                             length_penalty=length_penalty, pad_token_id=pad_token_id,
                                             restricter=restricter)
    else:
        token_ids = _beam_search_generate(decoder, tokens=tokens, state=state, max_length=max_length,
                                          max_len_a=max_len_a,
                                          num_beams=num_beams,
                                          bos_token_id=bos_token_id, eos_token_id=eos_token_id, do_sample=False,
                                          repetition_penalty=repetition_penalty,
                                          length_penalty=length_penalty, pad_token_id=pad_token_id,
                                          restricter=restricter)
    return token_ids
def _no_beam_search_generate(decoder: PromptBartDecoder, state, tokens=None, max_length=20, max_len_a=0.0,
                             bos_token_id=None, eos_token_id=None,
                             repetition_penalty=1.0, length_penalty=1.0, pad_token_id=0,
                             restricter=None):
    device = get_model_device(decoder)
    if tokens is None:
        if bos_token_id is None:
            raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.")
        batch_size = state.num_samples()
        if batch_size is None:
            raise RuntimeError("Cannot infer the number of samples from `state`.")
        tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
    batch_size = tokens.size(0)
    if state.num_samples() is not None:
        assert state.num_samples() == batch_size, "The number of samples in `tokens` and `state` should match."

    if eos_token_id is None:
        _eos_token_id = -1
    else:
        _eos_token_id = eos_token_id

    scores = decoder.decode(tokens=tokens, state=state)  # update state
    if restricter is not None:
        _, next_tokens = restricter(state, tokens, scores, num_beams=1)
    else:
        next_tokens = scores.argmax(dim=-1, keepdim=True)
    token_ids = torch.cat([tokens, next_tokens], dim=1)
    cur_len = token_ids.size(1)
    dones = token_ids.new_zeros(batch_size).eq(1).__or__(next_tokens.squeeze(1).eq(_eos_token_id))
    # tokens = tokens[:, -1:]

    if max_len_a != 0:
        # (bsz x num_beams, )
        if state.encoder_mask is not None:
            max_lengths = (state.encoder_mask.sum(dim=1).float() * max_len_a).long() + max_length
        else:
            max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)
        real_max_length = max_lengths.max().item()
    else:
        real_max_length = max_length
        if state.encoder_mask is not None:
            max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long() * max_length
        else:
            max_lengths = tokens.new_full((tokens.size(0),), fill_value=max_length, dtype=torch.long)

    while cur_len < real_max_length:
        scores = decoder.decode(tokens=token_ids, state=state)  # batch_size x vocab_size

        if repetition_penalty != 1.0:
            token_scores = scores.gather(dim=1, index=token_ids)
            lt_zero_mask = token_scores.lt(0).float()
            ge_zero_mask = lt_zero_mask.eq(0).float()
            token_scores = lt_zero_mask * repetition_penalty * token_scores + \
                           ge_zero_mask / repetition_penalty * token_scores
            scores.scatter_(dim=1, index=token_ids, src=token_scores)

        if eos_token_id is not None and length_penalty != 1.0:
            token_scores = scores / cur_len ** length_penalty  # batch_size x vocab_size
            eos_mask = scores.new_ones(scores.size(1))
            eos_mask[eos_token_id] = 0
            eos_mask = eos_mask.unsqueeze(0).eq(1)
            scores = scores.masked_scatter(eos_mask, token_scores)

        if restricter is not None:
            _, next_tokens = restricter(state, token_ids, scores, 1)
        else:
            next_tokens = scores.argmax(dim=-1, keepdim=True)
        next_tokens = next_tokens.squeeze(-1)

        if _eos_token_id != -1:
            next_tokens = next_tokens.masked_fill(max_lengths.eq(cur_len + 1), _eos_token_id)
        next_tokens = next_tokens.masked_fill(dones, pad_token_id)
        tokens = next_tokens.unsqueeze(1)

        token_ids = torch.cat([token_ids, tokens], dim=-1)  # batch_size x max_len

        end_mask = next_tokens.eq(_eos_token_id)
        dones = dones.__or__(end_mask)
        cur_len += 1

        if dones.min() == 1:
            break

    return token_ids


def _beam_search_generate(decoder: PromptBartDecoder, tokens=None, state=None, max_length=20, max_len_a=0.0,
                          num_beams=4, bos_token_id=None, eos_token_id=None, do_sample=True,
                          repetition_penalty=1.0, length_penalty=None, pad_token_id=0,
                          restricter=None) -> torch.LongTensor:
    assert do_sample is False
    # beam search
    device = get_model_device(decoder)
    if tokens is None:
        if bos_token_id is None:
            raise RuntimeError("You have to specify either `tokens` or `bos_token_id`.")
        batch_size = state.num_samples()
        if batch_size is None:
            raise RuntimeError("Cannot infer the number of samples from `state`.")
        tokens = torch.full([batch_size, 1], fill_value=bos_token_id, dtype=torch.long).to(device)
    batch_size = tokens.size(0)
    if state.num_samples() is not None:
        assert state.num_samples() == batch_size, "The number of samples in `tokens` and `state` should match."

    if eos_token_id is None:
        _eos_token_id = -1
    else:
        _eos_token_id = eos_token_id

    scores = decoder.decode(tokens=tokens, state=state)
    vocab_size = scores.size(1)
    assert vocab_size >= num_beams, "num_beams should be smaller than the vocabulary size."

    scores = F.log_softmax(scores, dim=-1)  # (batch_size, vocab_size)
    if restricter is not None:
        _next_scores, _next_tokens = restricter(state, tokens, scores, num_beams + 1)
    else:
        # bsz x (num_beams+1)
        _next_scores, _next_tokens = torch.topk(scores, num_beams + 1, dim=1, largest=True, sorted=True)

    indices = torch.arange(batch_size, dtype=torch.long).to(device)
    indices = indices.repeat_interleave(num_beams)
    state.reorder_state(indices)
    tokens = tokens.index_select(dim=0, index=indices)  # batch_size * num_beams x length

    if max_len_a != 0:
        # (bsz x num_beams, )
        if state.encoder_mask is not None:
            max_lengths = (state.encoder_mask.sum(dim=1).float() * max_len_a).long() + max_length
        else:
            max_lengths = tokens.new_full((batch_size * num_beams,), fill_value=max_length, dtype=torch.long)
        real_max_length = max_lengths.max().item()
    else:
        real_max_length = max_length
        if state.encoder_mask is not None:
            max_lengths = state.encoder_mask.new_ones(state.encoder_mask.size(0)).long() * max_length
        else:
            max_lengths = tokens.new_full((batch_size * num_beams,), fill_value=max_length, dtype=torch.long)
    hypos = [
        BeamHypotheses(num_beams, real_max_length, length_penalty, early_stopping=False) for _ in range(batch_size)
    ]

    not_eos_mask = _next_tokens.ne(_eos_token_id)
    keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)
    keep_mask = not_eos_mask.__and__(keep_mask)

    next_tokens = _next_tokens.masked_select(keep_mask).view(batch_size, num_beams)
    next_scores = _next_scores.masked_select(keep_mask).view(batch_size, num_beams)

    rows, cols = not_eos_mask.eq(0)[:, :num_beams].nonzero(as_tuple=True)
    if len(rows) > 0:
        for row, col in zip(rows.tolist(), cols.tolist()):
            _token = torch.cat([tokens[row * num_beams], _next_tokens[row, col:col + 1]], dim=0)
            hypos[row].add(_token.clone(), _next_scores[row, col].item())

    # (batch_size, cur_len)
    token_ids = torch.cat([tokens, next_tokens.view(-1, 1)], dim=-1)
    dones = [False] * batch_size

    beam_scores = next_scores.view(-1)  # batch_size * num_beams

    cur_len = token_ids.size(1)

    # 0, num_beams, 2*num_beams, ...
    batch_inds_with_numbeams_interval = (torch.arange(batch_size) * num_beams).view(-1, 1).to(token_ids)

    while cur_len < real_max_length:
        scores = decoder.decode(token_ids, state)  # (bsz x num_beams, vocab_size)

        if repetition_penalty != 1.0:
            token_scores = scores.gather(dim=1, index=token_ids)
            lt_zero_mask = token_scores.lt(0).float()
            ge_zero_mask = lt_zero_mask.eq(0).float()
            token_scores = lt_zero_mask * repetition_penalty * token_scores + \
                           ge_zero_mask / repetition_penalty * token_scores
            scores.scatter_(dim=1, index=token_ids, src=token_scores)

        if _eos_token_id != -1:
            max_len_eos_mask = max_lengths.eq(cur_len + 1)
            eos_scores = scores[:, _eos_token_id]
            scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores + 1e32, eos_scores)

        scores = F.log_softmax(scores, dim=-1)  # (batch_size * num_beams, vocab_size)
        _scores = scores + beam_scores[:, None]  # (batch_size * num_beams, vocab_size)
        _scores = _scores.view(batch_size, -1)  # (batch_size, num_beams*vocab_size)
        if restricter is not None:
            next_scores, ids = restricter(state, token_ids, _scores, 2 * num_beams)
        else:
            next_scores, ids = torch.topk(_scores, 2 * num_beams, dim=1, largest=True, sorted=True)  # (bsz, 2*num_beams)
        from_which_beam = ids // vocab_size  # (batch_size, 2*num_beams)
        next_tokens = ids % vocab_size  # (batch_size, 2*num_beams)

        not_eos_mask = next_tokens.ne(_eos_token_id)
        keep_mask = not_eos_mask.cumsum(dim=1).le(num_beams)
        keep_mask = not_eos_mask.__and__(keep_mask)

        _next_tokens = next_tokens.masked_select(keep_mask).view(-1, 1)
        _from_which_beam = from_which_beam.masked_select(keep_mask).view(batch_size, num_beams)
        _next_scores = next_scores.masked_select(keep_mask).view(batch_size, num_beams)
        beam_scores = _next_scores.view(-1)

        flag = True
        if cur_len + 1 == real_max_length:
            eos_batch_idx = torch.arange(batch_size).to(next_tokens).repeat_interleave(repeats=num_beams, dim=0)
            eos_beam_ind = torch.arange(num_beams).to(token_ids).repeat(batch_size)
            eos_beam_idx = from_which_beam[:, :num_beams].reshape(-1)
        else:
            effective_eos_mask = next_tokens[:, :num_beams].eq(_eos_token_id)  # batch_size x num_beams
            if effective_eos_mask.sum().gt(0):
                eos_batch_idx, eos_beam_ind = effective_eos_mask.nonzero(as_tuple=True)
                eos_beam_idx = eos_batch_idx * num_beams * 2 + eos_beam_ind
                eos_beam_idx = from_which_beam.view(-1)[eos_beam_idx]
            else:
                flag = False

        if flag:
            _token_ids = torch.cat([token_ids, _next_tokens], dim=-1)
            for batch_idx, beam_ind, beam_idx in zip(eos_batch_idx.tolist(), eos_beam_ind.tolist(),
                                                     eos_beam_idx.tolist()):
                if not dones[batch_idx]:
                    score = next_scores[batch_idx, beam_ind].item()
                    if _eos_token_id != -1:
                        hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx, :cur_len].clone(), score)
                    else:
                        hypos[batch_idx].add(_token_ids[batch_idx * num_beams + beam_idx].clone(), score)

        reorder_inds = (batch_inds_with_numbeams_interval + _from_which_beam).view(-1)  # flatten
        state.reorder_state(reorder_inds)
        token_ids = torch.cat([token_ids.index_select(index=reorder_inds, dim=0), _next_tokens], dim=-1)

        for batch_idx in range(batch_size):
            dones[batch_idx] = dones[batch_idx] or hypos[batch_idx].is_done(next_scores[batch_idx, 0].item()) or \
                               max_lengths[batch_idx * num_beams] == cur_len + 1

        cur_len += 1

        if all(dones):
            break

    # select the best hypotheses
    tgt_len = token_ids.new_zeros(batch_size)
    best = []

    for i, hypotheses in enumerate(hypos):
        best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
        if _eos_token_id != -1:
            best_hyp = torch.cat([best_hyp, best_hyp.new_ones(1) * _eos_token_id])
        tgt_len[i] = len(best_hyp)
        best.append(best_hyp)

    # generate the target batch
    decoded = token_ids.new_zeros(batch_size, tgt_len.max().item()).fill_(pad_token_id)
    for i, hypo in enumerate(best):
        decoded[i, :tgt_len[i]] = hypo

    return decoded
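
# A worked example (illustration only) of the flattening arithmetic used above:
# per-row offsets of num_beams plus each chosen parent beam yield indices into the
# (batch_size * num_beams)-sized flat batch that reorder_state expects.
def _demo_beam_flatten():
    batch_size, num_beams = 2, 3
    offsets = (torch.arange(batch_size) * num_beams).view(-1, 1)  # [[0], [3]]
    from_which_beam = torch.LongTensor([[2, 0, 1], [1, 1, 0]])
    reorder_inds = (offsets + from_which_beam).view(-1)
    assert reorder_inds.tolist() == [2, 0, 1, 4, 4, 3]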
class BeamHypotheses(object):
    def __init__(self, num_beams, max_length, length_penalty, early_stopping):
        """
        Initialize n-best list of hypotheses.
        """
        self.max_length = max_length - 1  # ignoring bos_token
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.num_beams = num_beams
        self.hyp = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.hyp)
    def add(self, hyp, sum_logprobs):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / len(hyp) ** self.length_penalty
        if len(self) < self.num_beams or score > self.worst_score:
            self.hyp.append((score, hyp))
            if len(self) > self.num_beams:
                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.hyp)])
                del self.hyp[sorted_scores[0][1]]
                self.worst_score = sorted_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)
    def is_done(self, best_sum_logprobs):
        """
        If there are enough hypotheses and none of the hypotheses still being
        generated can become better than the worst one kept so far, then we are
        done with this sentence.
        """
        if len(self) < self.num_beams:
            return False
        elif self.early_stopping:
            return True
        else:
            return self.worst_score >= best_sum_logprobs / self.max_length ** self.length_penalty
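
# A usage sketch (illustration only): BeamHypotheses keeps the num_beams best
# length-normalized hypotheses; adding a better one evicts the current worst, and
# is_done compares the worst kept score with the best still achievable score.
def _demo_beam_hypotheses():
    hypos = BeamHypotheses(num_beams=2, max_length=5, length_penalty=1.0, early_stopping=False)
    hypos.add(torch.LongTensor([0, 7, 1]), sum_logprobs=-1.5)  # score -0.5
    hypos.add(torch.LongTensor([0, 8, 1]), sum_logprobs=-0.9)  # score -0.3
    assert len(hypos) == 2
    hypos.add(torch.LongTensor([0, 9, 1]), sum_logprobs=-0.3)  # score -0.1 evicts -0.5
    assert len(hypos) == 2
    assert not hypos.is_done(best_sum_logprobs=-0.1)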