تجربة عملية على PyTorch 2.0: تسريع نماذج HuggingFace وTIMM!

يمكن لـ PyTorch 2.0 زيادة سرعة تدريب النموذج بمقدار 30%-200% باستخدام سطر بسيط من torch.compile(). سوف يوضح لك هذا البرنامج التعليمي كيفية إعادة إنتاج هذا التسريع فعليًا.
torch.compile() يستطيعجرب بسهولة برامج التجميع الخلفية المختلفة،يؤدي هذا إلى تسريع تشغيل كود PyTorch. إنه مثل torch.jit.script() بديل مباشر لـ ، والذي يمكن تشغيله مباشرة على nn.Module دون تعديل الكود المصدر.
في المقالة السابقة، قدمنا أن torch.compile يدعم كود PyTorch التعسفي، وتدفق التحكم، والطفرة، وإلى حد ما يدعم الأشكال الديناميكية.
من خلال اختبار 163 نموذجًا مفتوح المصدر، وجدنا أن torch.compile() يمكن أن يحقق تسريعًا يتراوح بين 30% و200%.
opt_module = torch.compile(module)
نتائج الاختبار مفصلة في:
سوف يوضح لك هذا البرنامج التعليمي كيفية الاستخدام torch.compile() تسريع تدريب النموذج.
المتطلبات والإعدادات
بالنسبة لوحدات معالجة الرسومات (تتمتع وحدات معالجة الرسومات الأحدث بتحسينات أداء أكثر أهمية):
pip3 install numpy --pre torch[dynamo] --force-reinstall --extra-index-url https://download.pytorch.org/whl/nightly/cu117
بالنسبة لوحدة المعالجة المركزية:
pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu
اختياري: التحقق من التثبيت
git clone https://github.com/pytorch/pytorch
cd tools/dynamo
python verify_dynamo.py
اختياري: تثبيت Docker
يتم توفير جميع التبعيات الضرورية في ملف PyTorch Nightly Binaries، والذي يمكن تنزيله عبر:
docker pull ghcr.io/pytorch/pytorch-nightly
للتجارب المخصصة،تأكد فقط من أن الحاوية يمكنها الوصول إلى جميع وحدات معالجة الرسومات:
docker run --gpus all -it ghcr.io/pytorch/pytorch-nightly:latest /bin/bash
يبدأ
مثال بسيط
دعونا نلقي نظرة على مثال بسيط أولاً، ونلاحظ أن التسارع أصبح أكثر وضوحًا مع وحدات معالجة الرسومات الأحدث.
import torch
def fn(x, y):
a = torch.sin(x).cuda()
b = torch.sin(y).cuda()
return a + b
new_fn = torch.compile(fn, backend="inductor")
input_tensor = torch.randn(10000).to(device="cuda:0")
a = new_fn()
لن يؤدي هذا المثال فعليًا إلى زيادة السرعة، ولكن يمكن استخدامه للبدء.
في هذا المثال،torch.cos() و الشعلة.الخطيئة() هي أمثلة على العمليات النقطية، فهي تعمل على المتجهات عنصرًا بعنصر. نقطة عملية أكثر شهرة هي شعلة.relu().
إن العمليات التي تتم نقطة بنقطة في الوضع الحريص ليست مثالية لأن كل مشغل يحتاج إلى قراءة موتر من الذاكرة، وإجراء بعض التغييرات، ثم كتابة هذه التغييرات مرة أخرى.
أحد أهم التحسينات في PyTorch 2.0 هو الاندماج.
لذلك في هذه الحالة يمكننا تحويل قراءتين وكتابتين إلى قراءة واحدة وكتابة واحدة، وهو أمر بالغ الأهمية في وحدات معالجة الرسومات الأحدث حيث يكون عنق الزجاجة هو عرض النطاق الترددي للذاكرة (مدى سرعة إرسال البيانات إلى وحدة معالجة الرسومات) بدلاً من الحوسبة (مدى سرعة وحدة معالجة الرسومات في إجراء عمليات النقطة العائمة).
التحسين الثاني المهم في PyTorch 2.0 هو الرسوم البيانية CUDA.
تساعد رسوم بيانية CUDA في التخلص من التكلفة الإضافية لتشغيل نوى فردية من برامج Python.
يدعم torch.compile() العديد من الخوادم الخلفية المختلفة، وأبرزها Inductor، الذي يمكنه إنشاء نوى Triton.
تمت كتابة هذه النوى بلغة بايثون.لكنها أفضل من معظم أنوية CUDA المكتوبة بخط اليد.بافتراض أن المثال أعلاه يسمى trig.py، يمكنك في الواقع فحص الكود الذي يولد نواة triton عن طريق تشغيل
TORCHINDUCTOR_TRACE=1 python trig.py
@pointwise(size_hints=[16384], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32'}, 'device': 0, 'constants': {}, 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]})
@triton.jit
def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr):
xnumel = 10000
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK])
xmask = xindex < xnumel
x0 = xindex
tmp0 = tl.load(in_ptr0 + (x0), xmask)
tmp1 = tl.sin(tmp0)
tmp2 = tl.sin(tmp1)
tl.store(out_ptr0 + (x0 + tl.zeros([XBLOCK], tl.int32)), tmp2, xmask)
ومن الكود أعلاه، يمكننا أن نرى أن: الخطايا لقد حدث الاندماج لأن الاثنين الخطيئة تتم العمليات في نواة Triton، ويتم تخزين المتغيرات المؤقتة في السجلات، والتي يمكن الوصول إليها بسرعة كبيرة.
مثال نموذجي حقيقي
خذ resnet50 في PyTorch Hub كمثال:
import torch
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
opt_model = torch.compile(model, backend="inductor")
model(torch.randn(1,3,64,64))
في التشغيل الفعلي، ستجد أن التشغيل الأول بطيء جدًا لأن النموذج قيد التجميع. ستكون سرعة التشغيل اللاحقة أسرع.لذلك قبل البدء في عملية القياس والتحليل، من الشائع أن نقوم بتسخين النموذج.
كما ترى، نستخدم "inductor" هنا لتمثيل اسم المترجم، ولكنه ليس البرنامج الخلفي الوحيد المتاح. يمكنك تشغيله في REPL torch._dynamo.list_backends() لرؤية القائمة الكاملة للواجهات الخلفية المتاحة.
يمكنك أيضًا المحاولة aot_cudagraphs أو موزع الشبكة .
مثال على نموذج وجه العناق
غالبًا ما يستخدم مجتمع PyTorch نماذج مدربة مسبقًا للمحولات أو TIMM:
أحد أهداف تصميم PyTorch 2.0 هو أن أي مجموعة تجميع يجب أن تكون قادرة على الاستخدام خارج الصندوق في الغالبية العظمى من النماذج التي يتم تشغيلها بالفعل.
هنا نقوم بتنزيل نموذج مدرب مسبقًا مباشرةً من مركز HuggingFace وتحسينه:
import torch
from transformers import BertTokenizer, BertModel
# Copy pasted from here https://huggingface.co/bert-base-uncased
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased").to(device="cuda:0")
model = torch.compile(model) # This is the only line of code that we changed
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt').to(device="cuda:0")
output = model(**encoded_input)
إذا قمت بإزالته من النموذج إلى(الجهاز="cuda:0") و المدخلات المشفرة سوف يقوم PyTorch 2.0 بإنشاء نوى C++ مُحسّنة للعمل على وحدة المعالجة المركزية.
يمكنك التحقق من نوى Triton أو C++ الخاصة بـ BERT، والتي هي أكثر تعقيدًا بشكل واضح من الأمثلة المثلثية المذكورة أعلاه. ولكن إذا كنت تعرف PyTorch فيمكنك تخطيه.
يمكن استخدام نفس الكود مع ما يلي للحصول على نتائج أفضل:
* https://github.com/huggingface/accelerate
* دي دي بي
مرة أخرى، حاول استخدام مثال TIMM:
import timm
import torch
model = timm.create_model('resnext101_32x8d', pretrained=True, num_classes=2)
opt_model = torch.compile(model, backend="inductor")
opt_model(torch.randn(64,3,7,7))
الهدف من PyTorch هو بناء مُجمِّع قادر على التكيف مع المزيد من النماذج وتسريع تشغيل معظم النماذج مفتوحة المصدر.قم بزيارة HuggingFace Hub الآن،تسريع نموذج TIMM مع PyTorch 2.0!