LLM Acceleration: Prompt Lookup Decode Coding

This post runs prompt lookup decoding (PLD) on a Colab L4 GPU (22 GB of memory), switching the model from Mistral-7B to Microsoft Phi-3-mini.

This post continues [[2024-08-19-Prompt_Lookup_Decode]] and focuses on the code. The code comes in three levels of abstraction, summarized below:

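|             | Orig PLD                       | HF Candidate                                                                       | HF Xformers                        |
| ----------- | ------------------------------ | ---------------------------------------------------------------------------------- | ---------------------------------- |
| Level       | Lowest                         | Middle                                                                             | Highest                            |
| Call        | `find_candidate_pred_tokens`   | `greedy_search_pld`                                                                | `model.generate`                   |
| When to use | To change the retrieval method | To modify speculative decoding without changing retrieval, or to modify retrieval | To use as-is, without modification |
| Goal        | Speed-up                       | Robustness                                                                         | Ease of use                        |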

Some observations and issues:

  • The Transformers `prompt_lookup_num_tokens` option shows the best speed-up across all the GPUs tested.
  • There is negligible difference between your original code and my revised version using the CandidateGenerator class. Something may be wrong; I would appreciate it if you could take some time to check it. :)
  • On the Colab T4 GPU, both the original and the revised versions behave strangely (no speed-up, or even a slowdown), while the Transformers `prompt_lookup_num_tokens` version behaves normally. I'm not sure why.
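| GPU      | PLD | Orig PLD (t/s) | Speed-up | CandidateGenerator (t/s) | Speed-up | Transformers PLD (t/s) | Speed-up |
| -------- | --- | -------------- | -------- | ------------------------ | -------- | ---------------------- | -------- |
| Colab L4 | On  | 33.9           | 2.2      | 33.6                     | 2.2      | 33.6                   | 2.3      |
| Colab L4 | Off | 15.2           |          | 15.2                     |          | 14.3                   |          |
| Colab T4 | On  | 11.1           | 0.7      | 14.3                     | 1.0      | 33.2                   | 2.2      |
| Colab T4 | Off | 15.6           |          | 14.6                     |          | 14.9                   |          |
| RTX3060  | On  | 24.7           | 2.0      | 26.4                     | 2.0      | 31.8                   | 2.5      |
| RTX3060  | Off | 12.4           |          | 13.2                     |          | 12.8                   |          |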
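
Each Speed-up value is the PLD-on throughput divided by the PLD-off throughput for the same GPU and implementation. The short check below recomputes them from the table (a sketch; the `results` dictionary layout and the `speed_up` helper are only for illustration) and reproduces the reported one-decimal figures, including the 0.7x slowdown of the original code on the T4.

```python
# Throughput in tokens/sec from the table above: (PLD on, PLD off)
results = {
    ("Colab L4", "Orig PLD"): (33.9, 15.2),
    ("Colab L4", "CandidateGenerator"): (33.6, 15.2),
    ("Colab L4", "Transformers PLD"): (33.6, 14.3),
    ("Colab T4", "Orig PLD"): (11.1, 15.6),
    ("Colab T4", "CandidateGenerator"): (14.3, 14.6),
    ("Colab T4", "Transformers PLD"): (33.2, 14.9),
    ("RTX3060", "Orig PLD"): (24.7, 12.4),
    ("RTX3060", "CandidateGenerator"): (26.4, 13.2),
    ("RTX3060", "Transformers PLD"): (31.8, 12.8),
}


def speed_up(on_tps: float, off_tps: float) -> float:
    """Speed-up = throughput with PLD / throughput without PLD, rounded to 1 decimal."""
    return round(on_tps / off_tps, 1)


for (gpu, impl), (on_tps, off_tps) in results.items():
    print(f"{gpu:9s} {impl:19s} {speed_up(on_tps, off_tps):.1f}x")
```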

Prompt Lookup Decode Code Review

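Both conceptually and in practice, this is one of the simplest methods that still works well. Because it is so simple, many model hubs have already integrated prompt lookup decoding into their generation code.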

Find_Candidate_Pred_Tokens

```python
import torch


@torch.no_grad()
def find_candidate_pred_tokens(input_ids, max_ngram_size=3, num_pred_tokens=10):
    input_length = input_ids.size(1)

    # Ensure max_ngram_size and num_pred_tokens are valid
    if max_ngram_size <= 0 or num_pred_tokens <= 0 or max_ngram_size > input_length:
        raise ValueError("Invalid max_ngram_size or num_pred_tokens")

    # Try the longest n-gram first, then fall back to shorter ones
    for ngram_size in range(max_ngram_size, 0, -1):
        # Extract the last n tokens as our search ngram
        ngram = input_ids[0, -ngram_size:].tolist()

        # Sliding windows of size ngram_size: shape (1, input_length - ngram_size + 1, ngram_size)
        windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)

        # Convert ngram to a tensor for comparison
        ngram_tensor = torch.tensor(ngram, device=input_ids.device).unsqueeze(0)

        # Find where the windows match the ngram
        matches = (windows == ngram_tensor).all(dim=2)

        # Get the indices of matches
        match_indices = matches.nonzero(as_tuple=True)[1]

        # Iterate through match indices to find a valid continuation
        for idx in match_indices:
            start_idx = idx + ngram_size
            end_idx = start_idx + num_pred_tokens
            # Ensure we don't go beyond the length of input_ids and avoid self-match
            if end_idx <= input_length and start_idx < input_length - ngram_size:
                return input_ids[0, start_idx:end_idx]

    # If no match is found, return an empty tensor
    return torch.tensor([], dtype=torch.long, device=input_ids.device)


# ANSI escape codes used by greedy_search_pld to color-code accepted draft tokens
COLORS = ["\x1b[31m", "\x1b[32m", "\x1b[34m", "\x1b[35m"]  # Red, Green, Blue, Magenta
UNDERLINE = "\x1b[4m"
RESET = "\x1b[0m"
```

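This code defines a function named find_candidate_pred_tokens that searches a given sequence of token IDs (input_ids) for a specific subsequence (an n-gram) and returns the tokens that follow the match.

Here is a breakdown of the code:

Function purpose:

find_candidate_pred_tokens finds a token sequence (an n-gram) in input_ids and returns the tokens that immediately follow it. The function tries the largest n-gram first (up to max_ngram_size) and returns the predicted tokens (num_pred_tokens of them) that follow the matched n-gram.

Parameters:

  • input_ids: a 2D tensor of token IDs. The function assumes the first dimension is the batch (although it only works with a batch size of 1) and the second dimension is the sequence length.

  • max_ngram_size: the maximum size of the n-gram to search for in input_ids. The function looks for sequences of this length first, then reduces the size until it finds a match.

  • num_pred_tokens: the number of tokens to return after the n-gram is found.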

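Code explanation:

  1. Input validation

    • The function first checks that max_ngram_size and num_pred_tokens are positive and that max_ngram_size is not larger than the length of input_ids. If any of these conditions is violated, it raises a ValueError.
  2. Main loop

    • The function then loops from max_ngram_size down to 1, looking for the largest n-gram that matches a sequence in input_ids.

    • n-gram extraction: for each ngram_size, it takes the last ngram_size tokens of input_ids and converts them into a tensor (ngram_tensor).

    • Sliding windows: input_ids.unfold(dimension=1, size=ngram_size, step=1) creates sliding windows of ngram_size tokens, i.e. overlapping subsequences of input_ids.

    • Match finding: it then checks which windows equal ngram_tensor. The result is stored in the boolean tensor matches.

    • Match indices: if matches are found, their positions are stored in match_indices.

  3. Returning the predicted tokens

    • The function iterates over the match indices to find a valid continuation of the n-gram, computing the start and end indices (start_idx and end_idx) of the tokens that follow the matched n-gram.

    • If a full continuation of num_pred_tokens tokens fits within input_ids and starts before the trailing n-gram itself, the function returns the predicted tokens (from start_idx to end_idx).

  4. No-match case

    • If no match is found after trying every n-gram size, the function returns an empty tensor.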
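
To see which tokens the lookup actually proposes, here is a plain-Python restatement of the same matching rule on a list of token IDs. It is a sketch for illustration only: `lookup_candidates` is a hypothetical helper, not part of the original code, and it mirrors the two conditions `end_idx <= input_length` and `start_idx < input_length - ngram_size` without using tensors.

```python
from typing import List


def lookup_candidates(tokens: List[int], max_ngram_size: int = 3, num_pred_tokens: int = 10) -> List[int]:
    """Plain-Python mirror of find_candidate_pred_tokens for a single sequence."""
    n_tokens = len(tokens)
    if max_ngram_size <= 0 or num_pred_tokens <= 0 or max_ngram_size > n_tokens:
        raise ValueError("Invalid max_ngram_size or num_pred_tokens")

    for ngram_size in range(max_ngram_size, 0, -1):
        ngram = tokens[-ngram_size:]  # the trailing n-gram we search for
        # Every window start whose window equals the trailing n-gram (like unfold + all(dim=2))
        match_indices = [i for i in range(n_tokens - ngram_size + 1)
                         if tokens[i:i + ngram_size] == ngram]
        for idx in match_indices:
            start_idx = idx + ngram_size
            end_idx = start_idx + num_pred_tokens
            # Need a full-length continuation that starts before the trailing n-gram
            if end_idx <= n_tokens and start_idx < n_tokens - ngram_size:
                return tokens[start_idx:end_idx]
    return []


# The trailing 3-gram [1, 2, 3] also occurs at position 0, so the 3 tokens after it are proposed
print(lookup_candidates([1, 2, 3, 4, 5, 1, 2, 3], max_ngram_size=3, num_pred_tokens=3))  # [4, 5, 1]
# No match is followed by 10 tokens, so nothing is proposed
print(lookup_candidates([1, 2, 3, 4, 5, 1, 2, 3], max_ngram_size=3, num_pred_tokens=10))  # []
```

The second call shows a consequence of `end_idx <= input_length`: a match is skipped unless a full `num_pred_tokens` continuation follows it, even when a few following tokens exist.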

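Colors and underline:

At the end of the code, some constants are defined for terminal color codes and underlining (COLORS, UNDERLINE, RESET). They are not used by find_candidate_pred_tokens itself; greedy_search_pld uses COLORS and RESET to color-code accepted draft tokens in the terminal.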


Greedy_Search_Pld Code

```python
import copy
from typing import List, Optional, Union

import torch
from transformers import LogitsProcessorList, StoppingCriteriaList
from transformers.generation.utils import GreedySearchDecoderOnlyOutput

# Private transformers helpers. Assumption: a transformers release in which they live in
# transformers.generation.candidate_generator; their names and location change between versions.
from transformers.generation.candidate_generator import (
    _crop_past_key_values,
    _prepare_attention_mask,
    _prepare_token_type_ids,
)


def greedy_search_pld(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        draft_matching_window_size = 3,
        draft_num_candidate_tokens = 10,
        print_output=True,
        **model_kwargs,
    ):
        # The tokenizer is a module-level global, only needed for print_output
        global tokenizer

        # init values
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id

        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None

        # init scores tuple (attentions / hidden states are not collected)
        scores = () if (return_dict_in_generate and output_scores) else None
        # Assumes the first stopping criterion is a MaxLengthCriteria
        max_len = stopping_criteria[0].max_length
        i = 0  # iteration counter (one forward pass per iteration)
        current_color_index = 0

        while True:
            i += 1
            cur_len = input_ids.shape[-1]

            # 1. Draft: copy the tokens that followed an earlier match of the trailing n-gram
            candidate_pred_tokens = find_candidate_pred_tokens(input_ids, draft_matching_window_size, draft_num_candidate_tokens)
            if len(candidate_pred_tokens) == 0:
                # No match: use a dummy candidate (token ID 100); the pass still yields one real token
                candidate_pred_tokens = torch.tensor([100], device=input_ids.device).unsqueeze(0)
            else:
                candidate_pred_tokens = candidate_pred_tokens.unsqueeze(0)

            candidate_input_ids = torch.cat((input_ids, candidate_pred_tokens), dim=1)
            candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
            candidate_kwargs = copy.copy(model_kwargs)

            # Extend the attention mask and token type ids to cover the candidate tokens
            candidate_kwargs = _prepare_attention_mask(candidate_kwargs, candidate_input_ids.shape[1], self.config.is_encoder_decoder)
            candidate_kwargs = _prepare_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
            model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)

            # 2. Verify: one forward pass scores every candidate position at once
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            # Logits at the last accepted token and at every candidate position
            new_logits = outputs.logits[:, -candidate_length - 1 :]
            selected_tokens = new_logits.argmax(dim=-1)
            candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
            # Number of leading candidates that equal the model's greedy choice
            n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()

            # Never generate past the maximum length
            n_matches = min(n_matches, max_len - cur_len - 1)

            if print_output:
                current_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
            # 3. Accept the matched candidates plus one token chosen by the model itself
            valid_tokens = selected_tokens[:, : n_matches + 1]
            input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
            new_cur_len = input_ids.shape[-1]

            if print_output:
                updated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
                # Find and print the newly added text
                if updated_text != current_text:
                    new_text = updated_text[len(current_text):]
                    if len(valid_tokens[0]) > 1:
                        # Several tokens accepted in one pass: highlight them in color
                        color = COLORS[current_color_index]
                        print(f"{color}{new_text}{RESET}", end='')
                        # Update color for next generation
                        current_color_index = (current_color_index + 1) % len(COLORS)
                    else:
                        print(f"{new_text}", end='')

            # 4. Drop cached keys/values of rejected candidates (the newest token is not cached yet)
            new_cache_size = new_cur_len - 1
            outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)
            model_kwargs["past_key_values"] = outputs.past_key_values

            # stop if any accepted token is an EOS token (handles one or several EOS IDs)
            if eos_token_id_tensor is not None and torch.isin(valid_tokens, eos_token_id_tensor).any():
                break

            # stop if we exceed the maximum length
            if stopping_criteria(input_ids, scores):
                break

        if return_dict_in_generate:
            return GreedySearchDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
            )
        else:
            return input_ids
```

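This code defines a method named greedy_search_pld that generates text with a custom greedy search strategy. It adds a draft-prediction mechanism: candidate tokens copied from the context are checked against the model's own predictions, and the output is printed color-coded for visualization.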

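Parameters:

  • input_ids: the token-ID sequence of the input prompt.
  • logits_processor: (optional) list of processors that modify the logits (for example, to filter certain tokens). Accepted but not used by this loop.
  • stopping_criteria: (optional) list of criteria that decide when to stop generating.
  • max_length: (optional) maximum length of the generated sequence. Not used by this loop; the maximum length is read from stopping_criteria[0].
  • pad_token_id: (optional) ID of the padding token.
  • eos_token_id: (optional) end-of-sequence token ID; either a single ID or a list of IDs.
  • output_attentions: (optional) whether to return attention scores.
  • output_hidden_states: (optional) whether to return hidden states.
  • output_scores: (optional) whether to return prediction scores.
  • return_dict_in_generate: (optional) whether to return an output object with detailed generation information.
  • synced_gpus: whether to synchronize GPUs during generation (for multi-GPU setups). Not used by this loop.
  • streamer: (optional) streamer that outputs tokens as they are generated. Not used by this loop.
  • draft_matching_window_size: maximum n-gram size used to match candidate tokens (passed to find_candidate_pred_tokens as max_ngram_size).
  • draft_num_candidate_tokens: number of candidate tokens drafted per step (passed to find_candidate_pred_tokens as num_pred_tokens).
  • print_output: a boolean indicating whether to print the generated output color-coded.
  • **model_kwargs: additional model arguments.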

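Main steps:

  1. Initialization

    • The function first initializes stopping_criteria, pad_token_id, and eos_token_id from the supplied values or from the model's generation config.
    • If eos_token_id is provided, it is converted to a tensor.
  2. Token-generation loop

    • The method enters a loop that generates one token or a small group of tokens per iteration until a stopping condition is met.
    • It calls find_candidate_pred_tokens to search the input (input_ids) for a matching token sequence and propose the next likely tokens (candidate_pred_tokens).
    • If no match is found, a single dummy candidate with token ID 100 is used.
  3. Preparing the candidate tokens

    • The candidate tokens are appended to input_ids, and the model runs one forward pass over the extended sequence to produce the logits for every candidate position (new_logits).
    • The most likely token at each position is then selected from the logits (selected_tokens).
  4. Matching the predicted tokens

    • The method compares selected_tokens with the candidate tokens and counts how many leading candidates match (n_matches).
    • Only the valid tokens (valid_tokens: the n_matches matched candidates plus one token predicted by the model) are appended to input_ids. The KV cache is then cropped to new_cur_len - 1 entries so that keys and values computed for rejected candidates are discarded.
  5. Printing the output

    • If print_output is enabled, the method prints the newly generated text in color-coded segments; each segment of more than one accepted token gets a different color for visualization.
  6. Stopping conditions

    • The loop checks whether the generated tokens contain eos_token_id or whether a stopping criterion is met. If either is true, the loop breaks and generation stops.
  7. Returning the result

    • If return_dict_in_generate is True, the method returns a GreedySearchDecoderOnlyOutput (a dict-like output object) containing the generated sequence and any requested extra data (such as scores).
    • Otherwise, it returns only the generated sequence (input_ids).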
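
The draft, verify, and accept steps above can be simulated without a real model. The sketch below is illustrative only: `toy_next_token` is a hypothetical deterministic stand-in for the model's greedy argmax, `draft` is a simplified single-n-gram lookup, `count_leading_matches` restates the `cumsum` matching rule, and one "forward pass" is counted per loop iteration. Unlike greedy_search_pld, which drafts a dummy token 100 when there is no match, the sketch simply passes an empty draft. It checks that accepting `n_matches + 1` tokens per pass gives exactly the same output as plain greedy decoding, with fewer passes when the text repeats itself.

```python
from typing import Callable, List, Tuple

NextToken = Callable[[List[int]], int]


def toy_next_token(prefix: List[int]) -> int:
    """Hypothetical stand-in for the model's greedy argmax: cycles through 1..5."""
    return prefix[-1] % 5 + 1


def draft(tokens: List[int], ngram_size: int = 2, num_pred: int = 4) -> List[int]:
    """Simplified lookup (one n-gram size, continuation may be shorter than num_pred):
    copy the tokens after the earliest earlier occurrence of the trailing n-gram."""
    ngram = tokens[-ngram_size:]
    for i in range(len(tokens) - ngram_size):  # stops before the trailing n-gram itself
        if tokens[i:i + ngram_size] == ngram:
            return tokens[i + ngram_size:i + ngram_size + num_pred]
    return []


def verify(prefix: List[int], candidates: List[int], next_token: NextToken) -> List[int]:
    """One simulated forward pass: the greedy choice after prefix + candidates[:k]
    for k = 0..len(candidates), like new_logits.argmax(dim=-1)."""
    seq = prefix + candidates
    return [next_token(seq[:len(prefix) + k]) for k in range(len(candidates) + 1)]


def count_leading_matches(candidates: List[int], selected: List[int]) -> int:
    """Same result as ((~(cand == sel[:-1])).cumsum(-1) < 1).sum(): length of the agreeing prefix."""
    n = 0
    for cand, sel in zip(candidates, selected):
        if cand != sel:
            break
        n += 1
    return n


def pld_generate(prompt: List[int], max_new: int, next_token: NextToken) -> Tuple[List[int], int]:
    tokens, passes = list(prompt), 0
    while len(tokens) - len(prompt) < max_new:
        candidates = draft(tokens)
        selected = verify(tokens, candidates, next_token)
        passes += 1
        n_matches = count_leading_matches(candidates, selected)
        remaining = max_new - (len(tokens) - len(prompt))
        n_matches = min(n_matches, remaining - 1)  # like min(n_matches, max_len - cur_len - 1)
        tokens += selected[:n_matches + 1]  # matched drafts + one token from the model
    return tokens, passes


def greedy_generate(prompt: List[int], max_new: int, next_token: NextToken) -> Tuple[List[int], int]:
    tokens = list(prompt)
    for _ in range(max_new):  # one forward pass per token
        tokens.append(next_token(tokens))
    return tokens, max_new


prompt = [1, 2, 3, 4, 5, 1, 2]
pld_out, pld_passes = pld_generate(prompt, 10, toy_next_token)
greedy_out, greedy_passes = greedy_generate(prompt, 10, toy_next_token)
print(pld_out == greedy_out, pld_passes, greedy_passes)  # True 2 10
```

Because every accepted token is the model's own greedy choice, PLD changes only how many tokens are produced per forward pass, never which tokens are produced.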

Key features:

  • Greedy search

CandidateGenerator Class

Hugging Face Transformers

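Hugging Face has already integrated prompt lookup decoding into Transformers as the `prompt_lookup_num_tokens` option of `generate()`. It is off by default; setting it to a positive value (10 below) enables PLD with that many candidate tokens per step.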

generation_output = model.generate(**input_ids, do_sample=False, max_new_tokens=512, streamer=streamer, prompt_lookup_num_tokens=10)

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer

model_name_or_path = "TheBloke/OpenHermes-2.5-Mistral-7B-AWQ"

# Loading the AWQ checkpoint requires the autoawq package
tokenizer = AutoTokenizer.from_pretrained("teknium/OpenHermes-2.5-Mistral-7B")
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map="auto")

# Placeholder: the article text used in the experiments is not included in this post
article = "..."

chat = [
    {
        "role": "system",
        "content": "You are Hermes 2, an unbiased AI assistant, and you always answer to the best of your ability."
    },
    {
        "role": "user",
        "content": (
            "You are given a partial and unparsed scientific article, please read it carefully and complete the "
            f"request below.{article}Please summarize the article in 5 sentences."
        )
    },
]
processed_chat = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
# Note: input_ids is a BatchEncoding (input_ids + attention_mask), unpacked with ** below
input_ids = tokenizer(processed_chat, return_tensors='pt').to(model.device)

streamer = TextStreamer(tokenizer)

# CUDA events to time the generation on the GPU
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()

# baseline without prompt lookup decode
# generation_output = model.generate(**input_ids, do_sample=False, max_new_tokens=512, streamer=streamer)

# with prompt lookup decode
generation_output = model.generate(**input_ids, do_sample=False, max_new_tokens=512, streamer=streamer, prompt_lookup_num_tokens=10)

end_event.record()
torch.cuda.synchronize()  # wait for the GPU so elapsed_time() is valid

max_memory = torch.cuda.max_memory_allocated(model.device)
print("Max memory (MB): ", max_memory * 1e-6)
new_tokens = generation_output.shape[1] - input_ids.input_ids.shape[1]
# elapsed_time() returns milliseconds
print("Throughput (tokens/sec): ", new_tokens / (start_event.elapsed_time(end_event) * 1.0e-3))
```


### Result A
* Mistral-7B: 4.15 GB (model in INT4 precision + KV cache)
* GPU: L4, 22.5 GB
* File: demo_work.py
* Input prompt: xxx

|                         | w/o PLD | w/ PLD | Comment  |
| ----------------------- | ------- | ------ | -------- |
| Max Memory (MB)         | 6386    | 6389   | the same |
| Throughput (tokens/sec) | 26.44   | 10.56  | 2.5X     |

### Result B

* Phi3.5-mini-3.8B:  7.6 GB (model FP16 precision)
* GPU: T4, L4, RTX3060
* File: git_hub.ipynb, git_hub.py
* Input prompt: xxx

  

| GPU             | `use_new_generate = True` | `use_new_generate = False` | `use_new_generate = True`, `print_output` disabled |
| --------------- | ------------------------- | -------------------------- | ----------------------------- |
| Colab GPU T4    | 11 tokens/sec             | 15 tokens/sec              |                               |
| Colab GPU L4    | 32 tokens/sec             | 15.2 tokens/sec            | 33.8 tokens/sec               |
| Desktop RTX3060 | 24 tokens/sec             | 13 tokens/sec              |                               |

Results for `use_new_generate = False`: total time 278.8662 seconds over 10 runs; average 15.1686 tokens/sec; 4230 tokens generated across all runs.

Results for `use_new_generate = True`: total time 147.6660 seconds over 10 runs; average 33.8602 tokens/sec; 5000 tokens generated across all runs.
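
The averages above are consistent with average throughput = total tokens / total time over all runs (4230 / 278.8662 ≈ 15.17 and 5000 / 147.6660 ≈ 33.86 tokens/sec), not a mean of per-run rates. A minimal sketch of such a multi-run measurement follows; `benchmark` and `run_once` are hypothetical names, `run_once` is assumed to perform one generation and return the number of new tokens, and the clock is injectable only so the helper can be tested.

```python
import time
from typing import Callable, Tuple


def benchmark(run_once: Callable[[], int], runs: int = 10,
              clock: Callable[[], float] = time.perf_counter) -> Tuple[float, int, float]:
    """Call run_once() `runs` times; return (total seconds, total new tokens, tokens/sec).

    run_once must perform one full generation and return the number of new tokens.
    On a GPU it should call torch.cuda.synchronize() before returning so the timing
    covers all queued kernels.
    """
    total_time, total_tokens = 0.0, 0
    for _ in range(runs):
        start = clock()
        total_tokens += run_once()
        total_time += clock() - start
    # Average throughput = total tokens / total time
    return total_time, total_tokens, total_tokens / total_time


if __name__ == "__main__":
    def fake_run() -> int:
        """Dummy workload standing in for one model.generate call."""
        time.sleep(0.01)
        return 50

    total_s, total_tok, tps = benchmark(fake_run, runs=3)
    print(f"Total time: {total_s:.4f} seconds over 3 runs, {total_tok} tokens, {tps:.1f} tokens/sec")
```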


### Llama.cpp

https://github.com/ggerganov/llama.cpp/tree/master/examples/lookup



## Principles and Implementation




## Appendix A: English Version of Code

## Prompt Lookup Decode


```python
import torch


@torch.no_grad()
def find_candidate_pred_tokens(input_ids, max_ngram_size=3, num_pred_tokens=10):
    input_length = input_ids.size(1)

    # Ensure max_ngram_size and num_pred_tokens are valid
    if max_ngram_size <= 0 or num_pred_tokens <= 0 or max_ngram_size > input_length:
        raise ValueError("Invalid max_ngram_size or num_pred_tokens")

    for ngram_size in range(max_ngram_size, 0, -1):
        # Extract the last n tokens as our search ngram
        ngram = input_ids[0, -ngram_size:].tolist()

        # Create sliding windows of size ngram_size
        windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
        
        # Convert ngram to a tensor for comparison
        ngram_tensor = torch.tensor(ngram, device=input_ids.device).unsqueeze(0)
        
        # Find where the windows match the ngram
        matches = (windows == ngram_tensor).all(dim=2)

        # Get the indices of matches
        match_indices = matches.nonzero(as_tuple=True)[1]

        # Iterate through match indices to find a valid continuation
        for idx in match_indices:
            start_idx = idx + ngram_size
            end_idx = start_idx + num_pred_tokens
            # Ensure we don't go beyond the length of input_ids and avoid self-match
            if end_idx <= input_length and start_idx < input_length - ngram_size:
                return input_ids[0, start_idx:end_idx]

    # If no match is found, return an empty tensor
    return torch.tensor([], dtype=torch.long, device=input_ids.device)

  
# ANSI escape codes used by greedy_search_pld to color-code accepted draft tokens
COLORS = ["\x1b[31m", "\x1b[32m", "\x1b[34m", "\x1b[35m"]  # Red, Green, Blue, Magenta
UNDERLINE = "\x1b[4m"
RESET = "\x1b[0m"
```

This code defines a function called find_candidate_pred_tokens that searches for a specific sequence (an n-gram) within a given sequence of token IDs (input_ids) and returns the tokens that follow it. The function is decorated with @torch.no_grad(), so PyTorch does not track operations for gradient computation, which saves memory and computation during inference.

Here’s a breakdown of the code:

Function Purpose:

The purpose of the find_candidate_pred_tokens function is to identify a sequence of tokens (an n-gram) within input_ids and return the tokens that immediately follow this sequence. The function tries to find the largest possible n-gram (up to max_ngram_size) and returns a set of predicted tokens (num_pred_tokens) that follow the identified n-gram.

Parameters:

  • input_ids: A 2D tensor containing sequences of token IDs. The function assumes the first dimension represents the batch (though it only works with a batch size of 1), and the second dimension represents the sequence length.

  • max_ngram_size: The maximum size of the n-gram to search for within input_ids. The function will look for sequences of this length first, then reduce the size until it finds a match.

  • num_pred_tokens: The number of tokens to return after the n-gram is found.

Code Explanation:

  1. Input Validation:

    • The function starts by validating that max_ngram_size and num_pred_tokens are positive and that max_ngram_size is not larger than the length of input_ids. If any of these conditions are violated, the function raises a ValueError.
  2. Main Loop:

    • The function then enters a loop, iterating from the max_ngram_size down to 1. The purpose of this loop is to find the largest n-gram that matches a sequence in input_ids.

    • ngram Extraction: For each ngram_size, it extracts the last ngram_size tokens from input_ids and converts this list of tokens into a tensor (ngram_tensor).

    • Sliding Windows: The function creates “sliding windows” of tokens of the same size as ngram_size across the input_ids using input_ids.unfold(dimension=1, size=ngram_size, step=1). This generates overlapping windows (subsequences) from the input_ids.

    • Match Finding: It then checks which of these windows match the extracted ngram by comparing each window with the ngram_tensor. The result is stored in matches, a tensor of boolean values.

    • Match Indices: If matches are found, the indices of these matches are stored in match_indices.

  3. Returning Predicted Tokens:

    • The function iterates over the found match indices to identify a valid continuation of the n-gram. It computes the starting and ending indices (start_idx and end_idx) for the tokens following the matched n-gram.

    • If a full continuation of num_pred_tokens tokens fits within input_ids and starts before the trailing n-gram itself, it returns the predicted tokens (from start_idx to end_idx).

  4. No Match Case:

    • If no match is found after trying all n-gram sizes, the function returns an empty tensor.

Colors and Underline:

  • At the end of the code, some constants are defined for terminal color codes and underlining (COLORS, UNDERLINE, RESET). They are not used by find_candidate_pred_tokens itself; greedy_search_pld uses COLORS and RESET to color-code accepted draft tokens in the terminal.
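import copy
from typing import List, Optional, Union

import torch
from transformers import LogitsProcessorList, StoppingCriteriaList
from transformers.generation.utils import GreedySearchDecoderOnlyOutput

try:  # newer transformers releases
    from transformers.generation.candidate_generator import _crop_past_key_values
except ImportError:  # older releases defined it in generation.utils
    from transformers.generation.utils import _crop_past_key_values

# tokenizer, COLORS, RESET and find_candidate_pred_tokens are defined elsewhere in the notebook.

def greedy_search_pld(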
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: bool = False,
        streamer: Optional["BaseStreamer"] = None,
        draft_matching_window_size = 3,
        draft_num_candidate_tokens = 10,
        print_output=True,
        **model_kwargs,
    ):

 
        # tokenizer is a notebook-level global, used only when print_output is True
        global tokenizer

        # init values
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id

        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
  
        # init scores tuple (attentions and hidden states are not collected)
        scores = () if (return_dict_in_generate and output_scores) else None
        # assumes stopping_criteria[0] exposes max_length (e.g. MaxLengthCriteria)
        max_len = stopping_criteria[0].max_length
        i = 0
        current_color_index = 0

        while True:
            i += 1
            cur_len = input_ids.shape[-1]
            candidate_pred_tokens = find_candidate_pred_tokens(input_ids, draft_matching_window_size, draft_num_candidate_tokens)
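            if len(candidate_pred_tokens) == 0:
                # no n-gram match: fall back to an arbitrary placeholder draft token (ID 100);
                # it is normally rejected, so this step yields one ordinary greedy token
                candidate_pred_tokens = torch.tensor([100], device=input_ids.device).unsqueeze(0)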
            else:
                candidate_pred_tokens = candidate_pred_tokens.unsqueeze(0)

            candidate_input_ids = torch.cat((input_ids, candidate_pred_tokens), dim=1)
            candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1]
            candidate_kwargs = copy.copy(model_kwargs)

            # If you need to extend attention mask and token type ids,
            # you'll need to implement the logic for these methods based on your model's requirements.
            # For instance, for attention mask you might do:
            # if "attention_mask" in candidate_kwargs:
            #     candidate_kwargs["attention_mask"] = torch.cat(
            #         [candidate_kwargs["attention_mask"], torch.ones((candidate_kwargs["attention_mask"].shape[0], candidate_length), dtype=torch.long, device=candidate_kwargs["attention_mask"].device)],
            #         dim=1,
            #     )

            # TODO (author's note): change to prepare_attention_mask!
            candidate_kwargs = self._extend_attention_mask(candidate_kwargs, candidate_input_ids.shape[1])
            candidate_kwargs = self._extend_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1])
            model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs)

            # prepare model inputs
            # model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            new_logits = outputs.logits[:, -candidate_length - 1 :]  # excludes the input prompt if present
            selected_tokens = new_logits.argmax(dim=-1)
            candidate_new_tokens = candidate_input_ids[:, -candidate_length:]
            n_matches = ((~(candidate_new_tokens == selected_tokens[:, :-1])).cumsum(dim=-1) < 1).sum()

            # if last_assistant_token_is_eos and n_matches == candidate_length: # todo: do this earlier somehow
            #     n_matches -= 1
            n_matches = min(n_matches, max_len - cur_len - 1)
            # print(n_matches)
            # i+= n_matches.item()

            if print_output:
                current_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
            valid_tokens = selected_tokens[:, : n_matches + 1]
            input_ids = torch.cat((input_ids, valid_tokens), dim=-1)
            new_cur_len = input_ids.shape[-1]

            if print_output:
                updated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
                # Find and print the newly added text
                if updated_text != current_text:
                    new_text = updated_text[len(current_text):]
                    if len(valid_tokens[0]) > 1:
                        color = COLORS[current_color_index]
                        print(f"{color}{new_text}{RESET}", end='')
                        # Update color for next generation
                        current_color_index = (current_color_index + 1) % len(COLORS)
                    else:
                        print(f"{new_text}", end='')

            new_cache_size = new_cur_len - 1
            outputs.past_key_values = _crop_past_key_values(self, outputs.past_key_values, new_cache_size)
            model_kwargs["past_key_values"] = outputs.past_key_values
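            # stop if an EOS token was accepted in this step (works for one or several EOS IDs)
            if eos_token_id_tensor is not None and torch.isin(valid_tokens, eos_token_id_tensor).any():
                break

            # stop if the stopping criteria (e.g. the maximum length) are met
            if stopping_criteria(input_ids, scores):
                break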

        if return_dict_in_generate:
            return GreedySearchDecoderOnlyOutput(
                sequences=input_ids,
                scores=scores,
                # attentions=decoder_attentions,
                # hidden_states=decoder_hidden_states,
            )
        else:
            return input_ids

This code defines greedy_search_pld, a method that generates text with greedy decoding accelerated by prompt lookup. At each step, find_candidate_pred_tokens copies a draft continuation from earlier in the sequence, the model checks the whole draft in a single forward pass, and the accepted tokens are printed, with multi-token drafts highlighted in color.
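The control flow of this loop can be simulated without a GPU. In the toy, pure-Python sketch below, next_token is a hypothetical stand-in for the model's argmax, and lookup, plain_greedy and pld_greedy are illustrative names, not part of the original code. Each loop iteration counts as one verification pass, because greedy_search_pld checks the entire draft in one forward pass; the sketch calls next_token once per draft position only to emulate that pass.

```python
from typing import Callable, List, Optional, Tuple

PLACEHOLDER_TOKEN = 100  # fallback draft token, as in greedy_search_pld


def lookup(tokens: List[int], max_ngram_size: int, num_pred_tokens: int) -> List[int]:
    """Compact prompt lookup: earliest usable match of the trailing n-gram, longest n first."""
    n = len(tokens)
    for size in range(min(max_ngram_size, n), 0, -1):  # clamps instead of raising
        ngram = tokens[-size:]
        for idx in range(n - size + 1):
            start, end = idx + size, idx + size + num_pred_tokens
            if tokens[idx:idx + size] == ngram and end <= n and start < n - size:
                return tokens[start:end]
    return []


def plain_greedy(next_token: Callable[[List[int]], int], prompt: List[int],
                 max_len: int) -> List[int]:
    """Reference decoder: one model call per generated token."""
    tokens = list(prompt)
    while len(tokens) < max_len:
        tokens.append(next_token(tokens))
    return tokens


def pld_greedy(next_token: Callable[[List[int]], int], prompt: List[int], max_len: int,
               eos_id: Optional[int] = None, window: int = 3,
               num_draft: int = 10) -> Tuple[List[int], int]:
    """Greedy decoding with prompt-lookup drafts; returns (tokens, verification passes)."""
    tokens, passes = list(prompt), 0
    while len(tokens) < max_len:
        draft = lookup(tokens, window, num_draft) or [PLACEHOLDER_TOKEN]
        passes += 1
        # Emulates one forward pass: greedy prediction at each draft position and one beyond.
        preds = [next_token(tokens + draft[:i]) for i in range(len(draft) + 1)]
        n_matches = 0
        while n_matches < len(draft) and preds[n_matches] == draft[n_matches]:
            n_matches += 1
        n_matches = min(n_matches, max_len - len(tokens) - 1)  # never overshoot max_len
        accepted = preds[: n_matches + 1]  # matched draft tokens plus one model token
        tokens += accepted
        if eos_id is not None and eos_id in accepted:
            break
    return tokens, passes


def toy_model(seq: List[int]) -> int:
    """Hypothetical 'model' whose greedy choice is always (last token + 1) mod 5."""
    return (seq[-1] + 1) % 5


prompt = [0, 1, 2, 3, 4, 0, 1, 2]
out, passes = pld_greedy(toy_model, prompt, max_len=20, num_draft=4)
print(out == plain_greedy(toy_model, prompt, 20), passes)  # True 3
```

Because every accepted token equals the model's own greedy choice, the output matches plain greedy decoding; in this toy run the 12 new tokens take 3 verification passes instead of 12.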

Parameters:

  • input_ids: The token IDs of the input prompt (batch size 1; find_candidate_pred_tokens only looks at input_ids[0]).
  • logits_processor: (Optional) A list of processors that modify the logits (e.g., filtering certain tokens). Accepted but not applied in this implementation.
  • stopping_criteria: (Optional) A list of criteria that determine when generation stops. In practice it is required: the maximum length is read from stopping_criteria[0].max_length, so the first entry must be a criterion such as MaxLengthCriteria.
  • max_length: (Optional) Not used; the maximum length comes from stopping_criteria.
  • pad_token_id: (Optional) The ID of the padding token. It is resolved from the generation config but not used afterwards.
  • eos_token_id: (Optional) The ID of the end-of-sequence token, either a single ID or a list of IDs.
  • output_attentions: (Optional) Forwarded to the model call; attentions are not returned.
  • output_hidden_states: (Optional) Forwarded to the model call; hidden states are not returned.
  • output_scores: (Optional) Whether to return the prediction scores. Only an empty scores tuple is created; no scores are recorded.
  • return_dict_in_generate: (Optional) Whether to return a GreedySearchDecoderOnlyOutput object instead of just the token IDs.
  • synced_gpus: Intended for multi-GPU setups; not used in this implementation.
  • streamer: (Optional) A streamer for token outputs; not used (printing is handled by print_output).
  • draft_matching_window_size: The maximum n-gram size used when searching for a match (passed as max_ngram_size to find_candidate_pred_tokens).
  • draft_num_candidate_tokens: The number of draft tokens copied after a match (passed as num_pred_tokens).
  • print_output: Whether to print the generated text as it is produced, with multi-token drafts in color.
  • **model_kwargs: Additional keyword arguments forwarded to prepare_inputs_for_generation; the updated past_key_values is stored here after each step.
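greedy_search_pld is written as a module-level function whose first parameter is self, so it must be bound to a model instance before it can be called like a method. One way to do this is Python's descriptor protocol (function.__get__); this is an assumption, since the notebook's own wiring is not shown in this section. DummyModel and greedy_search_pld_stub below are hypothetical stand-ins so the sketch runs without transformers.

```python
from typing import Any, List


class DummyModel:
    """Hypothetical stand-in for a Hugging Face causal LM instance."""

    def __init__(self, name: str) -> None:
        self.name = name


def greedy_search_pld_stub(self: Any, input_ids: List[int], draft_matching_window_size: int = 3,
                           draft_num_candidate_tokens: int = 10, **model_kwargs: Any):
    """Hypothetical stub with the same calling convention as greedy_search_pld."""
    return self.name, len(input_ids), draft_matching_window_size, draft_num_candidate_tokens


model = DummyModel("phi3-mini")
# function.__get__(instance, owner) returns a bound method, so the call below
# receives `model` as `self`, just like a method defined on the class.
model.greedy_search_pld = greedy_search_pld_stub.__get__(model, type(model))
print(model.greedy_search_pld([1, 2, 3], draft_num_candidate_tokens=5))  # ('phi3-mini', 3, 3, 5)

# With a real model, a call might look like this (hypothetical, not executed here):
# from transformers import MaxLengthCriteria, StoppingCriteriaList
# model.greedy_search_pld(input_ids,
#                         stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=512)]),
#                         draft_matching_window_size=3, draft_num_candidate_tokens=10)
```

types.MethodType(greedy_search_pld, model) is an equivalent alternative. Binding per instance leaves the model class untouched.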

Main Steps:

  1. Initialization:

    • The function starts by initializing the stopping_criteria, pad_token_id, and eos_token_id based on either provided values or the model’s configuration.
    • It converts the eos_token_id to a tensor if it is provided.
  2. Loop for Token Generation:

    • The method loops until a stopping condition is met, adding one or more tokens per iteration.
    • find_candidate_pred_tokens searches input_ids for an earlier occurrence of the last n-gram and returns the tokens that followed it as draft tokens (candidate_pred_tokens).
    • If no match is found, a single placeholder token (ID 100) is used as the draft, so the step effectively produces one ordinary greedy token.
  3. Prepare Candidate Tokens:

    • The draft tokens are appended to input_ids, and one forward pass over the extended sequence produces logits for every draft position plus the position right after it (new_logits).
    • The argmax of these logits gives the model's own greedy prediction at each position (selected_tokens).
  4. Matching Predicted Tokens:

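    • The method compares the draft tokens with selected_tokens and counts how many leading tokens match before the first mismatch (n_matches), capped so that the sequence does not exceed the maximum length.
    • It appends n_matches + 1 tokens from selected_tokens to input_ids (valid_tokens): the accepted draft tokens plus one extra token predicted by the model at the first mismatch (or after the last draft token).
    • The KV cache (past_key_values) is then cropped to the new length minus one, discarding entries computed for rejected draft tokens.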
  5. Print the Output:

    • If print_output is enabled, the method prints the newly added text after each step. Steps that accepted more than one token are printed in the next color from COLORS; single-token steps are printed without color.
  6. Stopping Conditions:

    • The loop checks whether the generated tokens include the eos_token_id or if the stopping criteria are met. If either condition is true, the loop breaks, and the generation stops.
  7. Return the Result:

    • If return_dict_in_generate is True, the method returns a GreedySearchDecoderOnlyOutput object with the generated sequences and a scores field, which this implementation never populates.
    • If not, it returns just the generated sequence (input_ids).
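The acceptance rule in step 4 and the cache crop in the code (new_cache_size = new_cur_len - 1) can be reproduced on plain lists for a single sequence. count_leading_matches and accept are illustrative names, not part of the original code; the first mirrors the tensor expression used for n_matches.

```python
from itertools import accumulate
from typing import List, Tuple


def count_leading_matches(draft: List[int], selected: List[int]) -> int:
    """Pure-Python version of ((~(draft == selected[:, :-1])).cumsum(dim=-1) < 1).sum().

    selected holds the model's argmax at each draft position plus one more position,
    so it is one token longer than draft.
    """
    mismatches = [int(d != s) for d, s in zip(draft, selected[:-1])]
    # The running mismatch count stays 0 only up to the first mismatch,
    # so "< 1" keeps exactly the leading run of matches.
    return sum(1 for c in accumulate(mismatches) if c < 1)


def accept(draft: List[int], selected: List[int], cur_len: int,
           max_len: int) -> Tuple[List[int], int]:
    """Return (valid_tokens, new_cache_size) following the greedy_search_pld logic."""
    n_matches = min(count_leading_matches(draft, selected), max_len - cur_len - 1)
    valid_tokens = selected[: n_matches + 1]  # accepted draft tokens + one model token
    # The last valid token has not been fed through the model yet, so the KV cache
    # keeps one entry fewer than the new sequence length.
    new_cache_size = cur_len + len(valid_tokens) - 1
    return valid_tokens, new_cache_size


print(count_leading_matches([8, 0, 1], [8, 9, 1, 4]))                  # 1
print(accept([8, 9, 1, 5], [8, 9, 2, 7, 3], cur_len=50, max_len=100))  # ([8, 9, 2], 52)
```

The first call shows why a cumulative sum is used instead of simply counting equal positions: the match at index 2 comes after a mismatch, so it is not accepted.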

Key Features:

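  • Greedy Search: The method always selects the most likely next token, so in principle it produces the same text as ordinary greedy decoding; the speed-up comes from accepting several tokens per forward pass.
  • Draft Prediction: Draft tokens are copied from earlier in the prompt or generated text by n-gram matching (no separate draft model) and verified by the main model before they are accepted.
  • Color-Coded Output: Text accepted from multi-token drafts is printed in rotating colors, which makes it easy to see where prompt lookup paid off.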

Reference

  • X (Twitter), Prompt Lookup Decode demo: https://twitter.com/joao_gante/status/1747322413006643259
  • HuggingFace, generation strategies: https://huggingface.co/docs/transformers/generation_strategies