MMLU Dataset and Performance

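Is MMLU single-answer or multiple-answer? It is single-answer: each question has exactly one correct option. The standard prompt nevertheless says "multiple choice questions", which in English only means picking one answer from several options. Could that wording confuse an LLM?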

Source

MMLU

  • Test dataset: hendrycks MMLU dataset: https://github.com/hendrycks/test
  • Test code: ollmer: [GitHub - ollmer/mmlu: Measuring Massive Multitask Language Understanding ICLR 2021](https://github.com/ollmer/mmlu)
  • Test code: deepeval (JUNK!): GitHub - confident-ai/deepeval: The LLM Evaluation Framework
  • Code: in ml_code/llm_evaluation_4_mmlu/evaluate_hf.ipynb and evaluate_llama.ipynb

MMLU Pro

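Comparison of MMLU and MMLU-Pro

| Feature | MMLU | MMLU-Pro |
| --- | --- | --- |
| Scope and content | A broad set of questions across many domains, mostly knowledge-based. Evaluates the model's recall and comprehension. | Builds on MMLU by adding more complex reasoning questions. Focuses on higher-order cognitive skills such as problem solving and critical thinking. |
| Difficulty | Mixed difficulty; some questions are relatively easy or trivial. | Significantly more challenging: easy and noisy questions are removed and questions requiring deeper reasoning are added. |
| Number of options | Four options per question. | Expanded to ten options, which raises difficulty and lowers the chance of a correct random guess. |
| Accuracy and sensitivity | Current models already reach high accuracy, so scores are plateauing; results are sensitive to prompt variations (4-5%). | Accuracy drops significantly (16% to 33% lower than on MMLU) because of the added difficulty. Sensitivity to prompt variations falls to about 2%, indicating greater stability and robustness. |
| Reasoning vs. direct answering | Models usually do well with direct answering. | Models using chain-of-thought (CoT) reasoning outperform direct answering, reflecting the dataset's emphasis on complex reasoning. |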
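
To make the "number of options" row concrete, here is a minimal sketch of the random-guess baseline for each benchmark. It assumes exactly one correct option per question and uniform guessing; the function name is only illustrative.

```python
def random_guess_accuracy(num_options: int) -> float:
    """Expected accuracy of uniform random guessing with one correct option."""
    if num_options < 1:
        raise ValueError("num_options must be positive")
    return 1.0 / num_options


mmlu_baseline = random_guess_accuracy(4)       # four options per question
mmlu_pro_baseline = random_guess_accuracy(10)  # ten options per question
print(f"MMLU: {mmlu_baseline:.0%}, MMLU-Pro: {mmlu_pro_baseline:.0%}")
```

Any reported accuracy should be read against this floor: a score of 30% is barely above chance on MMLU but well above chance on MMLU-Pro.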

Introduction

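Evaluating a large language model (LLM) is a key step in measuring how well it works, and an indispensable stage of the model pipeline. Common LLM leaderboards and platforms include the 🤗 Open LLM Leaderboard, OpenCompass, and the Chatbot Arena Leaderboard.

So how is LLM evaluation actually implemented?

This post uses the MMLU dataset as an example and looks at how mainstream open-source LLMs such as LLAMA-2 and BaiChuan-2 are evaluated and what results they achieve, in the hope that this one case gives a view of the bigger picture.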

Source: NLP(七十八)大模型探索:MMLU数据集评测 (NLP (78), Exploring LLMs: evaluation on the MMLU dataset) - My Github Blog (percent4.github.io)

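The MMLU dataset

The MMLU dataset is open-sourced on GitHub at https://github.com/hendrycks/test .

MMLU (Massive Multitask Language Understanding) is a benchmark that measures the world knowledge a model acquires during pretraining by evaluating it in zero-shot and few-shot settings. This makes the benchmark more challenging and closer to how we evaluate humans. It covers 57 subjects across STEM, the humanities, the social sciences, and other areas. Difficulty ranges from elementary to advanced level, and it tests both world knowledge and problem-solving ability. Subjects range from traditional fields such as mathematics and history to more specialized ones such as law and ethics. The granularity and breadth of the subjects make the benchmark well suited to identifying a model's blind spots.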

Categories and sub-categories

The category and sub-category structure of the MMLU dataset:

  • STEM (6 sub-categories)
    • Mathematics
    • Physics
    • Chemistry
    • Biology
    • Computer Science
    • Engineering
  • Humanities (3 sub-categories)
    • History
    • Philosophy
    • Law
  • Social Sciences (4 sub-categories)
    • Economics
    • Psychology
    • Political Science
    • Geography
  • Other (4 sub-categories)
    • Health
    • Culture
    • Business
    • Other

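The MMLU dataset contains 15,908 questions in total, split into a few-shot development set, a validation set, and a test set. The few-shot development set has 5 questions per subject; the validation set, which can be used for hyperparameter selection, has 1,540 questions; and the test set has 14,079 questions. Every subject has at least 100 test questions, which is longer than most exams designed to assess people.

Here is one example: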

Question: Glucose is transported into the muscle cell:

Choices:
A. via protein transporters called GLUT4.
B. only in the presence of insulin.
C. via hexokinase.
D. via monocarbylic acid transporters.

Correct answer: A

Dataset format: [question, subject, choices, answer]

  • question: the question text
  • subject: the category; there are 57 subjects in total
  • choices: the 4 candidate answers, corresponding to {A, B, C, D}
  • answer: the correct answer
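
As a rough illustration of this record layout, the sketch below models one record and renders it in the Question / Choices form shown earlier. The class name, the `render` helper, and the subject label are hypothetical, and storing the answer as a letter is an assumption (some releases of the dataset may store an index instead).

```python
from dataclasses import dataclass
from typing import List

LETTERS = ["A", "B", "C", "D"]


@dataclass
class MMLURecord:
    question: str
    subject: str
    choices: List[str]
    answer: str  # assumption: a letter in LETTERS, not an index


def render(record: MMLURecord) -> str:
    """Render a record in the Question / Choices / Correct answer layout."""
    lines = [f"Question: {record.question}", "", "Choices:"]
    lines += [f"{letter}. {text}" for letter, text in zip(LETTERS, record.choices)]
    lines += ["", f"Correct answer: {record.answer}"]
    return "\n".join(lines)


example = MMLURecord(
    question="Glucose is transported into the muscle cell:",
    subject="example_subject",  # hypothetical label, not taken from the dataset
    choices=[
        "via protein transporters called GLUT4.",
        "only in the presence of insulin.",
        "via hexokinase.",
        "via monocarbylic acid transporters.",
    ],
    answer="A",
)
print(render(example))
```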


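MMLU performance of different LLMs

The 57 subjects are usually grouped into 4 categories: STEM (science, technology, engineering, and mathematics), Humanities (law, philosophy; disciplines that date back to Aristotle), Social Sciences (economics, psychology; disciplines that try to attach themselves to science), and Other (mostly medicine-related).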

| STEM (19 subjects, 1800 questions) | Humanities (13 subjects, 1638 questions) | Social Sciences (12 subjects, 1368 questions) | Other (13 subjects, 1890 questions) |
| --- | --- | --- | --- |
| ABSTRACT_ALGEBRA | FORMAL_LOGIC | ECONOMETRICS | BUSINESS_ETHICS |
| ANATOMY | HIGH_SCHOOL_EUROPEAN_HISTORY | HIGH_SCHOOL_GEOGRAPHY | CLINICAL_KNOWLEDGE |
| ASTRONOMY | HIGH_SCHOOL_US_HISTORY | HIGH_SCHOOL_GOVERNMENT_AND_POLITICS | COLLEGE_MEDICINE |
| COLLEGE_BIOLOGY | HIGH_SCHOOL_WORLD_HISTORY | HIGH_SCHOOL_MACROECONOMICS | GLOBAL_FACTS |
| COLLEGE_CHEMISTRY | INTERNATIONAL_LAW | HIGH_SCHOOL_MICROECONOMICS | HUMAN_AGING |
| COLLEGE_COMPUTER_SCIENCE | JURISPRUDENCE | HIGH_SCHOOL_PSYCHOLOGY | MANAGEMENT |
| COLLEGE_MATHEMATICS | LOGICAL_FALLACIES | HUMAN_SEXUALITY | MARKETING |
| COLLEGE_PHYSICS | MORAL_DISPUTES | PROFESSIONAL_PSYCHOLOGY | MEDICAL_GENETICS |
| COMPUTER_SECURITY | MORAL_SCENARIOS | PUBLIC_RELATIONS | MISCELLANEOUS |
| CONCEPTUAL_PHYSICS | PHILOSOPHY | SECURITY_STUDIES | NUTRITION |
| ELECTRICAL_ENGINEERING | PREHISTORY | SOCIOLOGY | PROFESSIONAL_ACCOUNTING |
| ELEMENTARY_MATHEMATICS | PROFESSIONAL_LAW | US_FOREIGN_POLICY | PROFESSIONAL_MEDICINE |
| HIGH_SCHOOL_BIOLOGY | WORLD_RELIGIONS | | VIROLOGY |
| HIGH_SCHOOL_CHEMISTRY | | | |
| HIGH_SCHOOL_COMPUTER_SCIENCE | | | |
| HIGH_SCHOOL_MATHEMATICS | | | |
| HIGH_SCHOOL_PHYSICS | | | |
| HIGH_SCHOOL_STATISTICS | | | |
| MACHINE_LEARNING | | | |

4 categories, 17 sub-categories

categories = {
    "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering"],
    "humanities": ["history", "philosophy", "law"],
    "social sciences": ["politics", "culture", "economics", "geography", "psychology"],
    "other (business, health, misc.)": ["other", "business", "health"],
}

Use LLM_EVALUATION_4_MMLU! (official)
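
To show how the `categories` table is used, the sketch below resolves a subject to its top-level category through a sub-category lookup. The `subcategories` dict here is only a four-entry illustrative excerpt (the real mapping has one entry per subject), and `category_of` is a hypothetical helper.

```python
# Category table copied from the dict above.
categories = {
    "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering"],
    "humanities": ["history", "philosophy", "law"],
    "social sciences": ["politics", "culture", "economics", "geography", "psychology"],
    "other (business, health, misc.)": ["other", "business", "health"],
}

# Illustrative excerpt only (assumed entries); the full mapping covers all 57 subjects.
subcategories = {
    "abstract_algebra": ["math"],
    "professional_law": ["law"],
    "high_school_geography": ["geography"],
    "marketing": ["business"],
}


def category_of(subject):
    """Return every top-level category containing one of the subject's sub-categories."""
    subcats = subcategories[subject]
    return [
        cat
        for cat, members in categories.items()
        if any(sub in members for sub in subcats)
    ]


for name in subcategories:
    print(name, "->", category_of(name))
```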

Official results

[Figure: officially reported MMLU results]

Results from my own runs:

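| Model | Accuracy | STEM (18, 1800) | Humanities (13, 1638) | Social Sciences (12, 1368) | Others (14, 1890) |
| --- | --- | --- | --- | --- | --- |
| Llama3-8B (5 shot) | 65.0 | 55.8 | 58.9 | 76.3 | 71.6 |
| Llama2-7B (5 shot) | 46.0 | 37.0 | 43.3 | 51.8 | 52.4 |
| Mistral-7B (5 shot) | 62.6 | 52.6 | 56.5 | 73.5 | 70.4 |
| Phi3-3.8B-4K (5 shot) | 69.2 | 59.8 | 65.4 | 80.1 | 73.1 |

  • Phi-3-3.8B outperforms Llama3-8B? It is unclear whether that is because it was trained on similar data.
  • All models: STEM is the weakest category. Social Sciences is the strongest for every model except Llama2-7B, where Others is marginally higher (52.4 vs. 51.8).
  • All models: within STEM the weakest sub-category is math; within Humanities the weakest is law.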
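
Observations like these are easy to check mechanically. The sketch below copies the per-category numbers from the table above into a dict and reports the weakest and strongest category per model; the helper name is illustrative.

```python
# Per-category accuracy (%) copied from the results table above.
results = {
    "Llama3-8B": {"STEM": 55.8, "Humanities": 58.9, "Social Sciences": 76.3, "Others": 71.6},
    "Llama2-7B": {"STEM": 37.0, "Humanities": 43.3, "Social Sciences": 51.8, "Others": 52.4},
    "Mistral-7B": {"STEM": 52.6, "Humanities": 56.5, "Social Sciences": 73.5, "Others": 70.4},
    "Phi3-3.8B-4K": {"STEM": 59.8, "Humanities": 65.4, "Social Sciences": 80.1, "Others": 73.1},
}


def weakest_and_strongest(scores):
    """Return the (lowest-scoring, highest-scoring) category names."""
    weakest = min(scores, key=scores.get)
    strongest = max(scores, key=scores.get)
    return weakest, strongest


for model, scores in results.items():
    weakest, strongest = weakest_and_strongest(scores)
    print(f"{model}: weakest={weakest}, strongest={strongest}")
```

On these numbers STEM is the weakest category for all four models, while the strongest is Social Sciences for three of them and Others for Llama2-7B.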

Microsoft Phi3

Category and Sub-category

Average accuracy 0.483 - math
Average accuracy 0.707 - health
Average accuracy 0.606 - physics
Average accuracy 0.826 - business
Average accuracy 0.837 - biology
Average accuracy 0.558 - chemistry
Average accuracy 0.653 - computer science
Average accuracy 0.733 - economics
Average accuracy 0.593 - engineering
Average accuracy 0.681 - philosophy
Average accuracy 0.729 - other
Average accuracy 0.792 - history
Average accuracy 0.869 - geography
Average accuracy 0.818 - politics
Average accuracy 0.817 - psychology
Average accuracy 0.828 - culture
Average accuracy 0.551 - law
Average accuracy 0.598 - STEM
Average accuracy 0.654 - humanities
Average accuracy 0.801 - social sciences
Average accuracy 0.731 - other (business, health, misc.)
Average accuracy: 0.692
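
Logs in this "Average accuracy X - label" form can be turned into a dict for comparison across models. The following is a minimal sketch; the regular expression and the `"overall"` key used for the final unlabeled line are my own choices, not part of the evaluation script.

```python
import re

# Matches both "Average accuracy 0.483 - math" and "Average accuracy: 0.692".
LINE_RE = re.compile(r"^Average accuracy:? (\d+\.\d+)(?: - (.+))?$")


def parse_accuracy_log(text):
    """Map each label to its accuracy; the unlabeled total is stored as 'overall'."""
    scores = {}
    for line in text.splitlines():
        match = LINE_RE.match(line.strip())
        if match:
            label = match.group(2) or "overall"
            scores[label] = float(match.group(1))
    return scores


log = """Average accuracy 0.483 - math
Average accuracy 0.598 - STEM
Average accuracy 0.731 - other (business, health, misc.)
Average accuracy: 0.692"""

scores = parse_accuracy_log(log)
print(scores)
```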

Subject

Average accuracy 0.370 - abstract_algebra
Average accuracy 0.667 - anatomy
Average accuracy 0.776 - astronomy
Average accuracy 0.680 - business_ethics
Average accuracy 0.743 - clinical_knowledge
Average accuracy 0.826 - college_biology
Average accuracy 0.470 - college_chemistry
Average accuracy 0.560 - college_computer_science
Average accuracy 0.360 - college_mathematics
Average accuracy 0.688 - college_medicine
Average accuracy 0.363 - college_physics
Average accuracy 0.780 - computer_security
Average accuracy 0.706 - conceptual_physics
Average accuracy 0.456 - econometrics
Average accuracy 0.593 - electrical_engineering
Average accuracy 0.524 - elementary_mathematics
Average accuracy 0.603 - formal_logic
Average accuracy 0.390 - global_facts
Average accuracy 0.842 - high_school_biology
Average accuracy 0.601 - high_school_chemistry
Average accuracy 0.730 - high_school_computer_science
Average accuracy 0.812 - high_school_european_history
Average accuracy 0.869 - high_school_geography
Average accuracy 0.912 - high_school_government_and_politics
Average accuracy 0.749 - high_school_macroeconomics
Average accuracy 0.400 - high_school_mathematics
Average accuracy 0.840 - high_school_microeconomics
Average accuracy 0.444 - high_school_physics
Average accuracy 0.883 - high_school_psychology
Average accuracy 0.625 - high_school_statistics
Average accuracy 0.804 - high_school_us_history
Average accuracy 0.793 - high_school_world_history
Average accuracy 0.691 - human_aging
Average accuracy 0.771 - human_sexuality
Average accuracy 0.851 - international_law
Average accuracy 0.796 - jurisprudence
Average accuracy 0.804 - logical_fallacies
Average accuracy 0.554 - machine_learning
Average accuracy 0.806 - management
Average accuracy 0.897 - marketing
Average accuracy 0.810 - medical_genetics
Average accuracy 0.825 - miscellaneous
Average accuracy 0.757 - moral_disputes
Average accuracy 0.583 - moral_scenarios
Average accuracy 0.752 - nutrition
Average accuracy 0.765 - philosophy
Average accuracy 0.775 - prehistory
Average accuracy 0.582 - professional_accounting
Average accuracy 0.510 - professional_law
Average accuracy 0.761 - professional_medicine
Average accuracy 0.758 - professional_psychology
Average accuracy 0.736 - public_relations
Average accuracy 0.763 - security_studies
Average accuracy 0.866 - sociology
Average accuracy 0.860 - us_foreign_policy
Average accuracy 0.494 - virology
Average accuracy 0.825 - world_religions

Llama3-8B

Category and Sub-category

Average accuracy 0.444 - math
Average accuracy 0.705 - health
Average accuracy 0.548 - physics
Average accuracy 0.826 - business
Average accuracy 0.769 - biology
Average accuracy 0.521 - chemistry
Average accuracy 0.636 - computer science
Average accuracy 0.655 - economics
Average accuracy 0.634 - engineering
Average accuracy 0.580 - philosophy
Average accuracy 0.689 - other
Average accuracy 0.780 - history
Average accuracy 0.813 - geography
Average accuracy 0.810 - politics
Average accuracy 0.780 - psychology
Average accuracy 0.825 - culture
Average accuracy 0.499 - law
Average accuracy 0.558 - STEM
Average accuracy 0.589 - humanities
Average accuracy 0.763 - social sciences
Average accuracy 0.716 - other (business, health, misc.)
Average accuracy: 0.650

Subject

Average accuracy 0.320 - abstract_algebra
Average accuracy 0.644 - anatomy
Average accuracy 0.697 - astronomy
Average accuracy 0.630 - business_ethics
Average accuracy 0.751 - clinical_knowledge
Average accuracy 0.757 - college_biology
Average accuracy 0.500 - college_chemistry
Average accuracy 0.560 - college_computer_science
Average accuracy 0.360 - college_mathematics
Average accuracy 0.636 - college_medicine
Average accuracy 0.471 - college_physics
Average accuracy 0.800 - computer_security
Average accuracy 0.574 - conceptual_physics
Average accuracy 0.412 - econometrics
Average accuracy 0.634 - electrical_engineering
Average accuracy 0.437 - elementary_mathematics
Average accuracy 0.508 - formal_logic
Average accuracy 0.310 - global_facts
Average accuracy 0.774 - high_school_biology
Average accuracy 0.532 - high_school_chemistry
Average accuracy 0.680 - high_school_computer_science
Average accuracy 0.758 - high_school_european_history
Average accuracy 0.813 - high_school_geography
Average accuracy 0.896 - high_school_government_and_politics
Average accuracy 0.659 - high_school_macroeconomics
Average accuracy 0.374 - high_school_mathematics
Average accuracy 0.765 - high_school_microeconomics
Average accuracy 0.411 - high_school_physics
Average accuracy 0.846 - high_school_psychology
Average accuracy 0.639 - high_school_statistics
Average accuracy 0.828 - high_school_us_history
Average accuracy 0.827 - high_school_world_history
Average accuracy 0.700 - human_aging
Average accuracy 0.771 - human_sexuality
Average accuracy 0.860 - international_law
Average accuracy 0.759 - jurisprudence
Average accuracy 0.724 - logical_fallacies
Average accuracy 0.518 - machine_learning
Average accuracy 0.874 - management
Average accuracy 0.889 - marketing
Average accuracy 0.790 - medical_genetics
Average accuracy 0.814 - miscellaneous
Average accuracy 0.717 - moral_disputes
Average accuracy 0.411 - moral_scenarios
Average accuracy 0.755 - nutrition
Average accuracy 0.723 - philosophy
Average accuracy 0.725 - prehistory
Average accuracy 0.479 - professional_accounting
Average accuracy 0.452 - professional_law
Average accuracy 0.739 - professional_medicine
Average accuracy 0.721 - professional_psychology
Average accuracy 0.745 - public_relations
Average accuracy 0.755 - security_studies
Average accuracy 0.861 - sociology
Average accuracy 0.850 - us_foreign_policy
Average accuracy 0.566 - virology
Average accuracy 0.836 - world_religions

Mistral-7B

Category and Sub-category

Average accuracy 0.401 - math
Average accuracy 0.684 - health
Average accuracy 0.506 - physics
Average accuracy 0.796 - business
Average accuracy 0.756 - biology
Average accuracy 0.518 - chemistry
Average accuracy 0.612 - computer science
Average accuracy 0.636 - economics
Average accuracy 0.579 - engineering
Average accuracy 0.535 - philosophy
Average accuracy 0.699 - other
Average accuracy 0.766 - history
Average accuracy 0.768 - geography
Average accuracy 0.779 - politics
Average accuracy 0.747 - psychology
Average accuracy 0.813 - culture
Average accuracy 0.492 - law
Average accuracy 0.526 - STEM
Average accuracy 0.565 - humanities
Average accuracy 0.735 - social sciences
Average accuracy 0.704 - other (business, health, misc.)
Average accuracy: 0.626

Subject

Average accuracy 0.270 - abstract_algebra
Average accuracy 0.630 - anatomy
Average accuracy 0.658 - astronomy
Average accuracy 0.570 - business_ethics
Average accuracy 0.691 - clinical_knowledge
Average accuracy 0.729 - college_biology
Average accuracy 0.500 - college_chemistry
Average accuracy 0.520 - college_computer_science
Average accuracy 0.400 - college_mathematics
Average accuracy 0.653 - college_medicine
Average accuracy 0.382 - college_physics
Average accuracy 0.770 - computer_security
Average accuracy 0.574 - conceptual_physics
Average accuracy 0.491 - econometrics
Average accuracy 0.579 - electrical_engineering
Average accuracy 0.378 - elementary_mathematics
Average accuracy 0.405 - formal_logic
Average accuracy 0.360 - global_facts
Average accuracy 0.768 - high_school_biology
Average accuracy 0.527 - high_school_chemistry
Average accuracy 0.680 - high_school_computer_science
Average accuracy 0.788 - high_school_european_history
Average accuracy 0.768 - high_school_geography
Average accuracy 0.865 - high_school_government_and_politics
Average accuracy 0.662 - high_school_macroeconomics
Average accuracy 0.341 - high_school_mathematics
Average accuracy 0.664 - high_school_microeconomics
Average accuracy 0.331 - high_school_physics
Average accuracy 0.824 - high_school_psychology
Average accuracy 0.579 - high_school_statistics
Average accuracy 0.789 - high_school_us_history
Average accuracy 0.772 - high_school_world_history
Average accuracy 0.700 - human_aging
Average accuracy 0.786 - human_sexuality
Average accuracy 0.777 - international_law
Average accuracy 0.778 - jurisprudence
Average accuracy 0.791 - logical_fallacies
Average accuracy 0.491 - machine_learning
Average accuracy 0.825 - management
Average accuracy 0.880 - marketing
Average accuracy 0.740 - medical_genetics
Average accuracy 0.816 - miscellaneous
Average accuracy 0.708 - moral_disputes
Average accuracy 0.328 - moral_scenarios
Average accuracy 0.758 - nutrition
Average accuracy 0.695 - philosophy
Average accuracy 0.735 - prehistory
Average accuracy 0.493 - professional_accounting
Average accuracy 0.450 - professional_law
Average accuracy 0.688 - professional_medicine
Average accuracy 0.678 - professional_psychology
Average accuracy 0.673 - public_relations
Average accuracy 0.727 - security_studies
Average accuracy 0.831 - sociology
Average accuracy 0.860 - us_foreign_policy
Average accuracy 0.548 - virology
Average accuracy 0.830 - world_religions

Appendix

Evaluation code

Located in ml_code/MMLU/…/.pyh

Import the usual libraries.

# -*- encoding: utf-8 -*-

import argparse
import json
import os
import time

import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

Import the category and sub-category mappings from the categories module, and define the answer options for the multiple-choice questions: A, B, C, D.

from categories import categories, subcategories

choices = ["A", "B", "C", "D"]

Replace the underscores in a subject name with spaces to make it more readable.

def format_subject(subject):
    l = subject.split("_")
    s = ""
    for entry in l:
        s += " " + entry
    return s
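
One detail worth seeing in action: because every word is prefixed with a space, the result starts with a space. The sketch below is a compact equivalent of the function above (my rewrite, not the original code) and shows the effect when the result is dropped into the prompt header used later.

```python
def format_subject(subject: str) -> str:
    # Same behavior as the loop version: each word is prefixed with a space.
    return "".join(" " + word for word in subject.split("_"))


formatted = format_subject("high_school_physics")
header = "The following are multiple choice questions (with answers) about {}.".format(
    formatted
)
print(repr(formatted))
print(header)
```

The header therefore contains two spaces after "about". If a single space is preferred, `subject.replace("_", " ")` would avoid it, at the cost of a prompt that differs slightly from the one used for the results here.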

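Build a single (1-shot) example. This formats one question from the pandas DataFrame (df) as a text prompt by pairing the letters A, B, C, D with the candidate answers. If include_answer=True, the prompt also includes the answer.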

def format_example(df, idx, include_answer=True):
    prompt = df.iloc[idx, 0]  # question
    k = df.shape[1] - 2
    for j in range(k):
        prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j + 1]) # combine A/B/C/D with choices
    prompt += "\nAnswer:"
    if include_answer:
        prompt += " {}\n\n".format(df.iloc[idx, k + 1]) # combine Answer for in-context learning
    return prompt

Generate the few-shot prompt for a subject: a header naming the subject followed by k worked examples (k-shot).

def gen_prompt(train_df, subject, k=-1):
    prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(
        format_subject(subject)
    )
    if k == -1:
        k = train_df.shape[0]
    for i in range(k):
        prompt += format_example(train_df, i)  # add k-shots
    return prompt
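
To see what the final prompt looks like, here is a pandas-free sketch of the same assembly using plain lists. The row layout [question, A, B, C, D, answer] is assumed from how the code above indexes the DataFrame; `format_row`, `build_prompt`, and the arithmetic questions are made up for illustration.

```python
CHOICES = ["A", "B", "C", "D"]


def format_row(row, include_answer=True):
    """row = [question, option A, option B, option C, option D, answer letter]."""
    question, *options, answer = row
    prompt = question
    for letter, option in zip(CHOICES, options):
        prompt += "\n{}. {}".format(letter, option)
    prompt += "\nAnswer:"
    if include_answer:
        prompt += " {}\n\n".format(answer)
    return prompt


def build_prompt(dev_rows, test_row, subject_text, k):
    """Header + k solved examples + the unanswered test question."""
    header = "The following are multiple choice questions (with answers) about {}.\n\n".format(
        subject_text
    )
    shots = "".join(format_row(row) for row in dev_rows[:k])
    return header + shots + format_row(test_row, include_answer=False)


dev_rows = [["What is 2 + 2?", "3", "4", "5", "6", "B"]]  # made-up example
test_row = ["What is 3 * 3?", "6", "7", "8", "9", "D"]    # made-up example
prompt = build_prompt(dev_rows, test_row, "elementary mathematics", k=1)
print(prompt)
```

The prompt ends with a bare "Answer:" so that the model's next token is expected to be the answer letter.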

Evaluate the model

@torch.no_grad()
def eval(args, subject, model, tokenizer, dev_df, test_df):
    cors = []
    all_probs = []
    answers = choices[: test_df.shape[1] - 2]

    for i in range(test_df.shape[0]):
        # get prompt and make sure it fits
        k = args.ntrain  
        prompt_end = format_example(test_df, i, include_answer=False)
        train_prompt = gen_prompt(dev_df, subject, k)
        prompt = train_prompt + prompt_end

        input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

        while input_ids.shape[-1] > 2048:
            k -= 1
            train_prompt = gen_prompt(dev_df, subject, k)
            prompt = train_prompt + prompt_end
            input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(
                model.device
            )

        label = test_df.iloc[i, test_df.shape[1] - 1]

        logits = model(input_ids=input_ids).logits[0, -1]  # next-token logits at the last position (right after "Answer:")

        probs = (
            torch.nn.functional.softmax(
                torch.tensor(
                    [
                        logits[tokenizer("A").input_ids[-1]],
                        logits[tokenizer("B").input_ids[-1]],
                        logits[tokenizer("C").input_ids[-1]],
                        logits[tokenizer("D").input_ids[-1]],
                    ]
                ).float(),
                dim=0,
            )
            .detach()
            .cpu()
            .numpy()
        )
        pred = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(probs)]

        cor = pred == label
        cors.append(cor)
        all_probs.append(probs)

    acc = np.mean(cors)
    cors = np.array(cors)

    all_probs = np.array(all_probs)
    print("Average accuracy {:.3f} - {}".format(acc, subject))

    return cors, acc, all_probs
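
The core of the scoring step is a softmax restricted to the four answer-letter logits, followed by an argmax. The sketch below reproduces just that step in pure Python with made-up logits; `predict_from_logits` is a hypothetical helper, not part of the script.

```python
import math


def predict_from_logits(option_logits):
    """option_logits maps each answer letter to its next-token logit."""
    letters = list(option_logits)
    peak = max(option_logits.values())
    # Subtract the maximum before exponentiating for numerical stability.
    exps = [math.exp(option_logits[letter] - peak) for letter in letters]
    total = sum(exps)
    probs = {letter: value / total for letter, value in zip(letters, exps)}
    prediction = max(probs, key=probs.get)
    return prediction, probs


# Made-up logits for the tokens "A", "B", "C", "D".
pred, probs = predict_from_logits({"A": 2.1, "B": 4.0, "C": 0.3, "D": 1.2})
print(pred, probs)
```

Because only four logits are compared, the model is never penalized for putting probability on other tokens; it is scored purely on which letter it ranks highest.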

Main program: use -m to specify the model to run.

def main(args):
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.float16,
        load_in_8bit=False,
        low_cpu_mem_usage=True,
        device_map="auto",
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
    model.eval()
    subjects = sorted(
        [
            f.split("_test.csv")[0]
            for f in os.listdir(os.path.join(args.data_dir, "test"))
            if "_test.csv" in f
        ]
    )

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
    if not os.path.exists(os.path.join(args.save_dir, "results_{}".format(args.model.split("/")[-1]))):
        os.makedirs(os.path.join(args.save_dir, "results_{}".format(args.model.split("/")[-1])))

    all_cors = []
    subcat_cors = {
        subcat: [] for subcat_lists in subcategories.values() for subcat in subcat_lists
    }
    cat_cors = {cat: [] for cat in categories}

    start_time = time.time()
    for subject in subjects:
        dev_df = pd.read_csv(
            os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None
        )[: args.ntrain]
        test_df = pd.read_csv(
            os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None
        )

        cors, acc, probs = eval(args, subject, model, tokenizer, dev_df, test_df)
        subcats = subcategories[subject]
        for subcat in subcats:
            subcat_cors[subcat].append(cors)
            for key in categories.keys():
                if subcat in categories[key]:
                    cat_cors[key].append(cors)
        all_cors.append(cors)

        test_df["{}_correct".format(args.model)] = cors
        for j in range(probs.shape[1]):
            choice = choices[j]
            test_df["{}_choice{}_probs".format(args.model, choice)] = probs[:, j]
        test_df.to_csv(
            os.path.join(
                args.save_dir, "results_{}".format(args.model.split("/")[-1]), "{}.csv".format(subject)
            ),
            index=None,
        )

    results = {"subcategories": {}, "categories": {}}
    for subcat in subcat_cors:
        subcat_acc = np.mean(np.concatenate(subcat_cors[subcat]))
        results["subcategories"][subcat] = subcat_acc
        print("Average accuracy {:.3f} - {}".format(subcat_acc, subcat))

    for cat in cat_cors:
        cat_acc = np.mean(np.concatenate(cat_cors[cat]))
        results["categories"][cat] = cat_acc
        print("Average accuracy {:.3f} - {}".format(cat_acc, cat))
    weighted_acc = np.mean(np.concatenate(all_cors))
    results["weighted_accuracy"] = weighted_acc
    print("Average accuracy: {:.3f}".format(weighted_acc))

    results_file = os.path.join(
        args.save_dir, "accuracies_{}.json".format(args.model.replace("/", "_"))
    )
    end_time = time.time()
    results["cost_time"] = end_time - start_time
    with open(results_file, "w") as f:
        f.write(json.dumps(results, indent=4))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ntrain", "-k", type=int, default=5)  # number of few-shot examples (k)
    parser.add_argument("--data_dir", "-d", type=str, default="data")
    parser.add_argument("--save_dir", "-s", type=str, default="results")
    parser.add_argument("--model", "-m", type=str)  # model name or local path
    args = parser.parse_args()
    main(args)
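
Note that the overall number is computed over the concatenation of all per-question results, so subjects with more questions weigh more. The sketch below contrasts that with a plain mean of per-subject accuracies on made-up data; both helper names are illustrative.

```python
def micro_average(per_subject_correct):
    """Accuracy over all questions pooled together (what weighted_accuracy does)."""
    pooled = [c for cors in per_subject_correct.values() for c in cors]
    return sum(pooled) / len(pooled)


def macro_average(per_subject_correct):
    """Unweighted mean of the per-subject accuracies."""
    accs = [sum(cors) / len(cors) for cors in per_subject_correct.values()]
    return sum(accs) / len(accs)


# Made-up results: a large subject at 90% and a small one at 50%.
toy = {
    "subject_a": [True] * 9 + [False],
    "subject_b": [True, False],
}
micro = micro_average(toy)
macro = macro_average(toy)
print(f"micro={micro:.3f} macro={macro:.3f}")
```

When comparing against numbers from other sources, it is worth checking which of the two averages they report, since the gap can be noticeable when subject sizes differ a lot.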

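Do NOT use Deepeval (JUNK!)!!

  • Test with a few simple questions first.

  • It tends to get stuck on the start of the output: check whether every prediction is A, since the output has the form "Answer: B" (the leading "A" of "Answer" may be picked up).
  • It also tends to get stuck on the end of the output: check for that too, as in the figure below (all D).
  • Single-choice or multiple-choice questions? Single choice.

The early part looks OK, but the later part seems to have problems.

[Figure: Deepeval predictions, all D]

The following results seem completely wrong!!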

| Model | Accuracy | STEM (18, 1800) | Humanities (13, 1638) | Social Sciences (12, 1368) |
| --- | --- | --- | --- | --- |
| Llama3-8B-Chinese-Chat (3 shot) | | 48 | 65 | 69 |
| Llama3-8B (3 shot) | | 43 | 60 | 61 |
| Llama2-7B (3 shot) | | 30 | 37 | 37 |
| Mistral-7B (NG!) | | NA | NA | NA |
| Phi3-4K (nshot=0!) | | 41 | 54 | 55 |

Remove below

Open AI API

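Price difference between GPT-4o and GPT-3.5 Turbo as of 2024/6/19.

GPT-3.5 Turbo is 10x cheaper, but its accuracy is also lower.

[Figure: GPT-4o and GPT-3.5 Turbo pricing]

GPT settings

The second step is calling the GPT API. Key points:

  • Configure the model itself:
    • model = "gpt-4o" or "gpt-3.5-turbo-0125"
    • Temperature: T = 0 is greedy decoding. T = 0.1 still favors stable output. T = 1 favors creativity.
    • max_tokens: the maximum number of tokens to generate in the response (a cap on output length, not the context length or KV cache size)
    • top_p: nucleus-sampling threshold; frequency_penalty and presence_penalty: penalties that discourage repeated tokens
  • Use a structured message list as the input prompt:
    • role: sets the persona
    • content: the text of the question (the actual input). Image content can presumably be added as well.
  • The returned response has a structure similar to the input prompt:
    • message.content: the answer
    • choices: the list of generated completions; several are returned only when more than one completion is requested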
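
The effect of temperature can be illustrated with standard temperature-scaled softmax. This is a textbook sketch with made-up logits, not a description of how the API samples internally.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Divide logits by T before the softmax; lower T sharpens the distribution."""
    if temperature <= 0:
        raise ValueError("T = 0 corresponds to greedy decoding (argmax), not a softmax")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]  # subtract max for stability
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.5]  # made-up next-token logits
sharp = softmax_with_temperature(logits, 0.1)
flat = softmax_with_temperature(logits, 1.0)
print("T=0.1:", [round(p, 4) for p in sharp])
print("T=1.0:", [round(p, 4) for p in flat])
```

At T = 0.1 nearly all probability mass sits on the top token, which is why low temperatures give near-deterministic answers suitable for benchmarking.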

!pip install openai
from openai import OpenAI
client = OpenAI(api_key='put the key')

def run_one_question(question: str):
    response = client.chat.completions.create(
        model="gpt-4o",  # "gpt-3.5-turbo"
        messages=[
            {
                "role": "system",
                "content": "You are an knowledge expert, you are supposed to answer the multi-choice question to derive your final answer as `The answer is ...`."
            },
            {
                "role": "user",
                "content": [
                    {
                    "type": "text",
                    "text": question
                    }
                ]
            }
        ],
        temperature=0.1,
        max_tokens=4096,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
    )

    return response.choices[0].message.content

import json
import random
import re


def form_options(options: list):
    option_str = 'Options are:\n'
    opts = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J']
    for opt, o in zip(options, opts):
        option_str += f'({o}): {opt}' + '\n'
    return option_str


def get_prediction(output):
    pattern = r"answer is \(?([ABCDEFGHIJ])\)?"
    match = re.search(pattern, output)
    if match:
        return match.group(1)
    else:
        print("extraction failed, do a random guess")
        return random.choice(['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J'])
    
    
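from tqdm import tqdm

# The loop below assumes earlier setup that is not shown here: dataset, prompts, file,
# answers = [], success = fail = 0, and per_category_accuracy mapping each category
# to a [correct, wrong] pair.
print('----------------- Start Answering -------------------')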
pbar = tqdm(dataset['test'], desc='Processing', unit='question')
for entry in pbar:
    prefix = prompts[entry['category']]  # prefix consists of examples, 4 shots
    query = prefix + 'Q: ' + entry['question'] + '\n' + form_options(entry['options']) + '\n'
    answer = run_one_question(query)  # answer is from GPT
    entry['solution'] = answer  # store the GPT answer under 'solution', because entry['answer'] holds the ground truth
    answers.append(entry)
    prediction = get_prediction(answer) # get exact letter A, B, .. from GPT answer
    if entry["answer"] == prediction:   # compare ground truth with GPT answer
        success += 1
        per_category_accuracy[entry['category']][0] += 1
    else:
        fail += 1
        per_category_accuracy[entry['category']][1] += 1

    json_string = json.dumps(entry)
    file.write(json_string + '\n')

    success_rate = success / (success + fail)
    pbar.set_description(f'Processing (Success rate: {success_rate:.4f})')

for k, v in per_category_accuracy.items():
    if v[0] + v[1] > 0:
        print('accuracy: ', k, v[0] / (v[0] + v[1]))
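
The answer-extraction step is the part most likely to distort the score, so it is worth testing in isolation. The sketch below uses the same regular expression as get_prediction but also reports whether extraction succeeded; `extract_answer` and its optional `rng` argument are hypothetical additions for testability.

```python
import random
import re

OPTIONS = list("ABCDEFGHIJ")
PATTERN = re.compile(r"answer is \(?([ABCDEFGHIJ])\)?")


def extract_answer(output, rng=None):
    """Return (letter, extracted). Falls back to a random letter when nothing matches."""
    match = PATTERN.search(output)
    if match:
        return match.group(1), True
    rng = rng or random.Random()
    return rng.choice(OPTIONS), False


print(extract_answer("Let's think step by step. The answer is (C)."))
print(extract_answer("I am not sure.", random.Random(0)))
```

Counting how often the second element is False shows how much of the reported accuracy comes from random fallback guesses rather than from the model.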