Integration: vLLM Invocation Layer

Use a vLLM server or a locally hosted vLLM instance in your PromptNode

Authors
Lukas Kreussel


Use vLLM in your Haystack pipeline to run fast, self-hosted LLMs.


Installation

Install the wrapper via pip:

pip install vllm-haystack

Usage

This integration provides two invocation layers:

  • vLLMInvocationLayer: To use models hosted on a vLLM server
  • vLLMLocalInvocationLayer: To use locally hosted vLLM models

Use a Model Hosted on a vLLM Server

To use a model served by a vLLM server, pass vLLMInvocationLayer as the invocation layer of a PromptModel.

Here is a simple example of how a PromptNode can be created with the wrapper.

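from haystack.nodes import PromptNode, PromptModel
from vllm_haystack import vLLMInvocationLayer

API = "http://localhost:8000/v1"  # Replace this with your API-URL (this local address is only an example)

# model_name_or_path is left empty: the model is inferred from the vLLM server.
# vLLM does not require an API key by default, so a placeholder value is passed.
model = PromptModel(
    model_name_or_path="",
    invocation_layer_class=vLLMInvocationLayer,
    max_length=256,
    api_key="EMPTY",
    model_kwargs={
        "api_base": API,
        "maximum_context_length": 2048,
    },
)

prompt_node = PromptNode(model_name_or_path=model, top_k=1, max_length=256)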

The model is inferred from the model served by the vLLM server, so model_name_or_path can be left empty. For more configuration examples, see the unit tests.
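Before wiring the server into a PromptNode, it can help to check what the server actually serves. The sketch below talks to the server directly with the standard library, assuming it exposes the OpenAI-compatible `GET /models` and `POST /completions` endpoints under the base URL. The address `http://localhost:8000/v1` and all helper names are illustrative assumptions, not part of vllm-haystack, and this is not necessarily how the invocation layer performs its own requests.

```python
import json
import urllib.request

# Assumed example address of a vLLM OpenAI-compatible server; replace with your API-URL.
API_BASE = "http://localhost:8000/v1"


def parse_model_ids(payload: dict) -> list:
    """Extract model IDs from an OpenAI-style `GET /models` response body."""
    return [entry["id"] for entry in payload.get("data", [])]


def list_served_models(api_base: str = API_BASE, timeout: float = 5.0) -> list:
    """Ask the server which model(s) it currently serves (needs a running server)."""
    with urllib.request.urlopen(f"{api_base}/models", timeout=timeout) as resp:
        return parse_model_ids(json.load(resp))


def build_completion_request(
    api_base: str, model: str, prompt: str, max_tokens: int = 256
) -> urllib.request.Request:
    """Build an OpenAI-style `POST /completions` request (not sent here)."""
    body = json.dumps(
        {"model": model, "prompt": prompt, "max_tokens": max_tokens}
    ).encode("utf-8")
    return urllib.request.Request(
        f"{api_base}/completions",
        data=body,
        # "EMPTY" mirrors the placeholder key; a server started without an API key ignores it.
        headers={"Content-Type": "application/json", "Authorization": "Bearer EMPTY"},
        method="POST",
    )
```

With a server running, `list_served_models()` should return the name of the loaded model, and sending `build_completion_request(API_BASE, name, "What is Haystack?")` through `urllib.request.urlopen` returns a JSON body whose `choices[0]["text"]` holds the generated text.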

Hosting a vLLM Server

To start an OpenAI-compatible server with vLLM, follow the steps in the Quickstart section of the vLLM documentation.
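As a rough sketch, the launch command can be assembled from Python, assuming the `vllm.entrypoints.openai.api_server` module entry point with its `--model`, `--host` and `--port` options (recent vLLM releases also offer a `vllm serve` command, so check the Quickstart for the version you installed). The helper names are hypothetical, and `facebook/opt-125m` is only an example model.

```python
import shlex
import sys


def build_vllm_server_command(model: str, host: str = "0.0.0.0", port: int = 8000) -> list:
    """Argument list that would start vLLM's OpenAI-compatible server for `model`."""
    return [
        sys.executable, "-m", "vllm.entrypoints.openai.api_server",
        "--model", model,
        "--host", host,
        "--port", str(port),
    ]


def api_base_for(host: str = "localhost", port: int = 8000) -> str:
    """Base URL a client on `host` would pass as api_base for that server."""
    return f"http://{host}:{port}/v1"


cmd = build_vllm_server_command("facebook/opt-125m")
print(shlex.join(cmd))
# To actually launch it (requires vLLM and a supported GPU):
# import subprocess; subprocess.Popen(cmd)
```

Keeping host and port in one place makes it easy to derive the matching `api_base`: with the defaults above, a client on the same machine would use `api_base_for()`, i.e. `http://localhost:8000/v1`.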

Use a Model Hosted Locally

⚠️ To run vLLM locally, you need vLLM installed and a supported GPU.

If you don’t want to use an API server, the wrapper also provides vLLMLocalInvocationLayer, which runs vLLM on the same machine as Haystack.
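Because the local layer needs vLLM and a GPU in the same environment as Haystack, a quick preflight check can surface missing pieces before a model load fails. This is a hypothetical helper using only the standard library; the GPU test is a heuristic that just looks for NVIDIA's `nvidia-smi` tool on PATH, so it misses other accelerators and cannot confirm that vLLM supports the card.

```python
import importlib.util
import shutil


def local_vllm_preflight() -> list:
    """Return likely problems; an empty list means these basic checks passed."""
    problems = []
    if importlib.util.find_spec("vllm") is None:
        problems.append("The 'vllm' package is not installed (pip install vllm).")
    if importlib.util.find_spec("vllm_haystack") is None:
        problems.append(
            "The 'vllm_haystack' package is not installed (pip install vllm-haystack)."
        )
    # Heuristic only: checks for the NVIDIA driver tool, not for actual vLLM GPU support.
    if shutil.which("nvidia-smi") is None:
        problems.append("No 'nvidia-smi' found on PATH; an NVIDIA GPU may be unavailable.")
    return problems


for problem in local_vllm_preflight():
    print("warning:", problem)
```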

Here is a simple example of how a PromptNode can be created with the vLLMLocalInvocationLayer.

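from haystack.nodes import PromptNode, PromptModel
from vllm_haystack import vLLMLocalInvocationLayer

MODEL = "facebook/opt-125m"  # Example only: replace with any Hugging Face model supported by vLLM

# vLLM loads and runs the model in this process, so no api_base or api_key is passed.
model = PromptModel(
    model_name_or_path=MODEL,
    invocation_layer_class=vLLMLocalInvocationLayer,
    max_length=256,
    model_kwargs={
        "maximum_context_length": 2048,
    },
)

prompt_node = PromptNode(model_name_or_path=model, top_k=1, max_length=256)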