| |
| |
|
|
| import torch |
| import torch.nn as nn |
| import torch |
| from torch.autograd import Variable |
| import copy |
class Seq2Seq(nn.Module):
    """
    Build Sequence-to-Sequence.

    The encoder runs over `source_ids` with `use_cache=True`, and its
    `past_key_values` are handed to the decoder as context. Token id 1 is
    treated as padding everywhere in this class (hard-coded; presumably the
    RoBERTa `<pad>` id -- confirm against the tokenizer in use).

    Parameters:

    * `encoder`- encoder of seq2seq model. e.g. roberta
    * `decoder`- decoder of seq2seq model. e.g. transformer
    * `config`- configuration of encoder model (reads `hidden_size` and `vocab_size`).
    * `beam_size`- beam size for beam search.
    * `max_length`- max length of target for beam search.
    * `sos_id`- start of symbol ids in target for beam search.
    * `eos_id`- end of symbol ids in target for beam search.
    """

    def __init__(self, encoder, decoder, config, beam_size=None, max_length=None, sos_id=None, eos_id=None):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.config = config
        # Cached lower-triangular (causal) pattern. Kept at 1024x1024 and under the
        # name "bias" so existing checkpoints containing this buffer still load.
        self.register_buffer(
            "bias", torch.tril(torch.ones((1024, 1024), dtype=torch.uint8)).view(1, 1024, 1024)
        )
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Tie the output projection to the encoder's input word embeddings.
        self.lm_head.weight = self.encoder.embeddings.word_embeddings.weight
        self.lsm = nn.LogSoftmax(dim=-1)

        self.beam_size = beam_size
        self.max_length = max_length
        self.sos_id = sos_id
        self.eos_id = eos_id

    def _encode(self, source_ids):
        """Run the encoder with a pairwise padding mask (no causal part) and caching on."""
        not_pad = source_ids.ne(1)
        mask = not_pad[:, None, :] * not_pad[:, :, None]
        return self.encoder(source_ids, attention_mask=mask, use_cache=True)

    def _decoder_mask(self, ids, context_len):
        """
        Build the decoder attention mask for the positions after `context_len`.

        `ids` is the context (source) ids concatenated with the decoder ids
        along the last dim. Returns a bool tensor of shape
        (batch, ids.size(-1) - context_len, ids.size(-1)) that is causal over
        the full sequence and masks out padding columns (id 1).
        """
        total = ids.size(-1)
        if total <= self.bias.size(-1):
            causal = self.bias[:, context_len:total, :total].bool()
        else:
            # Longer than the cached buffer: build the causal pattern on the fly
            # rather than slicing a too-small window (which yields a wrong shape).
            causal = torch.tril(torch.ones((total, total), dtype=torch.uint8, device=ids.device))
            causal = causal[None, context_len:, :].bool()
        return causal & ids[:, None, :].ne(1)

    def forward(self, source_ids, target_ids=None):
        """
        Compute the teacher-forced LM loss, or beam-search decode when `target_ids` is None.

        Returns `(loss, loss * n_active, n_active)`, where `loss` is the mean
        cross-entropy over non-pad shifted target positions and `n_active` is
        the number of such positions. When `target_ids` is None, returns the
        result of `generate(source_ids)` instead.
        """
        if target_ids is None:
            return self.generate(source_ids)

        encoder_output = self._encode(source_ids)
        ids = torch.cat((source_ids, target_ids), -1)
        mask = self._decoder_mask(ids, source_ids.size(-1))
        out = self.decoder(target_ids, attention_mask=mask,
                           past_key_values=encoder_output.past_key_values).last_hidden_state
        lm_logits = self.lm_head(out)

        # Position t predicts token t+1; padded label positions are excluded.
        active_loss = target_ids[..., 1:].ne(1).view(-1)
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = target_ids[..., 1:].contiguous()

        loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1))[active_loss],
                        shift_labels.view(-1)[active_loss])

        return loss, loss * active_loss.sum(), active_loss.sum()

    def generate(self, source_ids):
        """
        Beam-search decode every row of `source_ids`.

        Returns a LongTensor that stacks, per source row, the hypotheses from
        `Beam` (truncated to `beam_size`), each right-padded with 0 to
        `max_length`. The expected shape is (batch, beam_size, max_length),
        assuming `Beam` yields `beam_size` hypotheses. Requires `beam_size`,
        `max_length`, `sos_id` and `eos_id` to have been set.
        """
        encoder_output = self._encode(source_ids)
        preds = []
        # Filler for short hypotheses. It is created on the input's device rather
        # than as a hard-coded `torch.cuda` tensor, so CPU inference does not crash here.
        zero = torch.zeros(1, dtype=torch.long, device=source_ids.device)
        # Non-pad length of each source row. The slices below assume padding sits
        # at the end of each row -- TODO confirm with the data pipeline.
        source_len = source_ids.ne(1).sum(-1).tolist()
        for i in range(source_ids.shape[0]):
            # Per-example cache, cropped to the real source length and tiled once per
            # beam. Assumes each cached tensor is (batch, heads, seq, head_dim) -- TODO confirm.
            context = [[x[i:i + 1, :, :source_len[i]].repeat(self.beam_size, 1, 1, 1) for x in y]
                       for y in encoder_output.past_key_values]
            beam = Beam(self.beam_size, self.sos_id, self.eos_id)
            input_ids = beam.getCurrentState()
            context_ids = source_ids[i:i + 1, :source_len[i]].repeat(self.beam_size, 1)
            for _ in range(self.max_length):
                if beam.done():
                    break
                ids = torch.cat((context_ids, input_ids), -1)
                mask = self._decoder_mask(ids, context_ids.size(-1))
                out = self.decoder(input_ids, attention_mask=mask, past_key_values=context).last_hidden_state
                # Only the newest position's hidden state scores the next token.
                hidden_states = out[:, -1, :]
                out = self.lsm(self.lm_head(hidden_states)).data
                beam.advance(out)
                # Reorder the existing prefixes to follow the surviving beams,
                # then append the tokens chosen at this step.
                input_ids.data.copy_(input_ids.data.index_select(0, beam.getCurrentOrigin()))
                input_ids = torch.cat((input_ids, beam.getCurrentState()), -1)
            hyp = beam.getHyp(beam.getFinal())
            pred = beam.buildTargetTokens(hyp)[:self.beam_size]
            pred = [torch.cat([x.view(-1) for x in p] + [zero] * (self.max_length - len(p))).view(1, -1)
                    for p in pred]
            preds.append(torch.cat(pred, 0).unsqueeze(0))

        preds = torch.cat(preds, 0)
        return preds
| |
| |
|
|
class Beam(object):
    """
    Beam-search bookkeeping for a single source sequence.

    Parameters:

    * `size`- number of hypotheses kept per step (beam width).
    * `sos`- token id placed in slot 0 of the initial state.
    * `eos`- token id that marks a hypothesis as finished.
    * `device`- device for the internal tensors. Defaults to CUDA when it is
      available and to CPU otherwise (CUDA used to be hard-coded).
    """

    def __init__(self, size, sos, eos, device=None):
        self.size = size
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        # Cumulative score of each live hypothesis.
        self.scores = torch.zeros(size, device=self.device)

        # Backpointers: prevKs[t][k] is the beam slot that hypothesis k extended at step t.
        self.prevKs = []

        # Token chosen at each step for each beam slot. Only slot 0 starts with
        # `sos`; the other slots hold 0, and the first `advance` reads only row 0.
        self.nextYs = [torch.zeros(size, dtype=torch.long, device=self.device)]
        self.nextYs[0][0] = sos

        self._eos = eos
        # Set once the best (slot 0) hypothesis has emitted eos.
        self.eosTop = False

        # (score, timestep, slot) for every hypothesis that emitted eos.
        self.finished = []

    def getCurrentState(self):
        "Get the outputs for the current timestep."
        return self.nextYs[-1].view(-1, 1)

    def getCurrentOrigin(self):
        "Get the backpointers for the current timestep."
        return self.prevKs[-1]

    def advance(self, wordLk):
        """
        Advance the beam by one step.

        Parameters:

        * `wordLk`- scores (e.g. log-probs) of each next token for every beam
          slot (K x words). On the first step only row 0 is used.

        Updates `scores`, `prevKs`, `nextYs`, `finished` and `eosTop` in
        place. Returns None.
        """
        numWords = wordLk.size(1)

        if len(self.prevKs) > 0:
            # Add each slot's running score to every candidate continuation.
            beamLk = wordLk + self.scores.unsqueeze(1).expand_as(wordLk)
            # Slots that already emitted eos must not be extended further.
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] == self._eos:
                    beamLk[i] = -1e20
        else:
            beamLk = wordLk[0]
        flatBeamLk = beamLk.view(-1)
        bestScores, bestScoresId = flatBeamLk.topk(self.size, 0, True, True)

        self.scores = bestScores

        # Split each flat index back into (source slot, token id).
        prevK = bestScoresId // numWords
        self.prevKs.append(prevK)
        self.nextYs.append(bestScoresId - prevK * numWords)

        for i in range(self.nextYs[-1].size(0)):
            if self.nextYs[-1][i] == self._eos:
                self.finished.append((self.scores[i], len(self.nextYs) - 1, i))

        if self.nextYs[-1][0] == self._eos:
            self.eosTop = True

    def done(self):
        """True once the top slot has emitted eos and at least `size` hypotheses finished."""
        return self.eosTop and len(self.finished) >= self.size

    def getFinal(self):
        """
        Return up to `size` (score, timestep, slot) entries.

        Finished hypotheses come first, sorted by descending score. If fewer
        than `size` finished, the list is topped up with the best still-live
        slots from the last step.
        """
        if len(self.finished) == 0:
            self.finished.append((self.scores[0], len(self.nextYs) - 1, 0))
        self.finished.sort(key=lambda a: -a[0])
        # `<` rather than `!=`: with more than `size` finished entries, the old
        # check sliced with a negative bound and appended live slots to `finished`.
        if len(self.finished) < self.size:
            unfinished = []
            for i in range(self.nextYs[-1].size(0)):
                if self.nextYs[-1][i] != self._eos:
                    unfinished.append((self.scores[i], len(self.nextYs) - 1, i))
            unfinished.sort(key=lambda a: -a[0])
            self.finished += unfinished[:self.size - len(self.finished)]
        return self.finished[:self.size]

    def getHyp(self, beam_res):
        """
        Walk back to construct the full hypothesis.

        For each (score, timestep, slot) entry, follow the backpointers from
        `timestep` down to step 1 and return the token lists in forward order.
        """
        hyps = []
        for _, timestep, k in beam_res:
            hyp = []
            for j in range(len(self.prevKs[:timestep]) - 1, -1, -1):
                hyp.append(self.nextYs[j + 1][k])
                k = self.prevKs[j][k]
            hyps.append(hyp[::-1])
        return hyps

    def buildTargetTokens(self, preds):
        """Truncate each hypothesis at its first eos (the eos itself is dropped)."""
        sentence = []
        for pred in preds:
            tokens = []
            for tok in pred:
                if tok == self._eos:
                    break
                tokens.append(tok)
            sentence.append(tokens)
        return sentence
| |
|
|