Skip to content

groq_provider

GroqProvider dataclass

Bases: LLMProvider

Groq LLM Provider for Wintermute AI using LiteLLM.

Source code in wintermute/ai/providers/groq_provider.py
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
@dataclass
class GroqProvider(LLMProvider):
    """Groq LLM provider for Wintermute AI, backed by LiteLLM.

    Chat completions are routed through ``litellm.completion`` using model ids
    that carry the ``groq/`` prefix. ``embed`` and ``count_tokens`` are local
    placeholders and never contact Groq.
    """

    # Groq API key forwarded to LiteLLM on every call; None leaves credential
    # resolution to LiteLLM (presumably environment variables — confirm).
    api_key: Optional[str] = None
    # Name reported by the ``name`` property (set by ``register(as_name=...)``).
    _name: str = "groq"
    # Model used when a ChatRequest does not specify one.
    _default_model: str = "groq/llama-3.3-70b-versatile"

    @property
    def name(self) -> str:
        """Return the name this provider instance reports."""
        return self._name

    @property
    def description(self) -> str:
        """Return a short human-readable summary of this provider."""
        return "Inference routing via Groq Cloud (Llama, Mixtral). No specialized RAG knowledge."

    def list_models(self) -> list[ModelInfo]:
        """List the Groq models this provider advertises (a static list)."""
        return [
            ModelInfo(
                "groq/llama-3.3-70b-versatile", "llama-3.3", 128_000, True, True, True
            ),
            ModelInfo(
                "groq/llama-3.1-70b-versatile", "llama-3.1", 128_000, True, True, True
            ),
            ModelInfo(
                "groq/llama-3.1-8b-instant", "llama-3.1", 128_000, True, True, True
            ),
        ]

    def chat(self, req: ChatRequest) -> ChatResponse:
        """Send a chat completion request to Groq using LiteLLM.

        Args:
            req: The chat request. ``req.model`` falls back to the provider's
                default model. ``req.tools`` are converted to the
                ``{"type": "function", "function": {...}}`` schema.
                ``req.response_format == "json"`` requests a JSON-object
                response.

        Returns:
            A ChatResponse holding:
            - the assistant text (an empty string when the model returned none);
            - any tool calls;
            - token counts (0 when not reported);
            - the round-trip latency of the completion call, in milliseconds.

        Raises:
            RuntimeError: If the LiteLLM call fails or returns no choices.
        """
        model_id = req.model or self._default_model

        # litellm expects groq/ prefix for models if not already present
        if not model_id.startswith("groq/"):
            model_id = f"groq/{model_id}"

        # Shallow-copy each message's attributes so downstream code cannot
        # mutate the caller's message objects through their live __dict__.
        messages = [dict(m.__dict__) for m in req.messages]

        tools_payload = None
        if req.tools:
            tools_payload = [
                {
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.input_schema,
                    },
                }
                for t in req.tools
            ]

        # perf_counter is monotonic, so the measurement is immune to wall-clock
        # adjustments. It spans only the completion call, not local parsing.
        start = time.perf_counter()
        try:
            response = litellm.completion(
                model=model_id,
                messages=messages,
                temperature=float(req.temperature),
                max_tokens=req.max_tokens,
                tools=tools_payload,
                tool_choice=req.tool_choice if tools_payload else None,
                api_key=self.api_key,
                response_format=(
                    {"type": "json_object"} if req.response_format == "json" else None
                ),
            )
        except Exception as e:
            # Normalise every LiteLLM/transport failure to a single error type.
            raise RuntimeError(f"LiteLLM Groq completion failed: {e}") from e
        latency = int((time.perf_counter() - start) * 1000)

        if not response.choices:
            raise RuntimeError("LiteLLM Groq completion returned no choices")
        message = response.choices[0].message

        tool_calls = [
            ToolCall(
                id=tc.id,
                name=tc.function.name,
                arguments=tc.function.arguments,
            )
            for tc in (getattr(message, "tool_calls", None) or [])
        ]

        # Treat a missing usage object, or missing/None counts, as 0.
        usage = getattr(response, "usage", None)

        return ChatResponse(
            content=message.content or "",
            tool_calls=tool_calls,
            model=model_id,
            provider=self.name,
            prompt_tokens=getattr(usage, "prompt_tokens", None) or 0,
            completion_tokens=getattr(usage, "completion_tokens", None) or 0,
            latency_ms=latency,
        )

    def embed(
        self, texts: Iterable[str], model: Optional[str] = None
    ) -> list[list[float]]:
        """Return placeholder embeddings: one 1536-element zero vector per text.

        Groq is not contacted and ``model`` is ignored.
        """
        return [[0.0] * 1536 for _ in texts]

    def count_tokens(self, text: str, model: Optional[str] = None) -> int:
        """Estimate the token count as len(text) // 4, never less than 1.

        This is a character-based heuristic; ``model`` is ignored.
        """
        return max(1, len(text) // 4)

chat(req)

Send a chat completion request to Groq using LiteLLM.

Source code in wintermute/ai/providers/groq_provider.py
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
def chat(self, req: ChatRequest) -> ChatResponse:
    """Send a chat completion request to Groq using LiteLLM.

    Args:
        req: The chat request. ``req.model`` falls back to the provider's
            default model. ``req.tools`` are converted to the
            ``{"type": "function", "function": {...}}`` schema.
            ``req.response_format == "json"`` requests a JSON-object response.

    Returns:
        A ChatResponse holding:
        - the assistant text (an empty string when the model returned none);
        - any tool calls;
        - token counts (0 when not reported);
        - the round-trip latency of the completion call, in milliseconds.

    Raises:
        RuntimeError: If the LiteLLM call fails or returns no choices.
    """
    model_id = req.model or self._default_model

    # litellm expects groq/ prefix for models if not already present
    if not model_id.startswith("groq/"):
        model_id = f"groq/{model_id}"

    # Shallow-copy each message's attributes so downstream code cannot
    # mutate the caller's message objects through their live __dict__.
    messages = [dict(m.__dict__) for m in req.messages]

    tools_payload = None
    if req.tools:
        tools_payload = [
            {
                "type": "function",
                "function": {
                    "name": t.name,
                    "description": t.description,
                    "parameters": t.input_schema,
                },
            }
            for t in req.tools
        ]

    # perf_counter is monotonic, so the measurement is immune to wall-clock
    # adjustments. It spans only the completion call, not local parsing.
    start = time.perf_counter()
    try:
        response = litellm.completion(
            model=model_id,
            messages=messages,
            temperature=float(req.temperature),
            max_tokens=req.max_tokens,
            tools=tools_payload,
            tool_choice=req.tool_choice if tools_payload else None,
            api_key=self.api_key,
            response_format=(
                {"type": "json_object"} if req.response_format == "json" else None
            ),
        )
    except Exception as e:
        # Normalise every LiteLLM/transport failure to a single error type.
        raise RuntimeError(f"LiteLLM Groq completion failed: {e}") from e
    latency = int((time.perf_counter() - start) * 1000)

    if not response.choices:
        raise RuntimeError("LiteLLM Groq completion returned no choices")
    message = response.choices[0].message

    tool_calls = [
        ToolCall(
            id=tc.id,
            name=tc.function.name,
            arguments=tc.function.arguments,
        )
        for tc in (getattr(message, "tool_calls", None) or [])
    ]

    # Treat a missing usage object, or missing/None counts, as 0.
    usage = getattr(response, "usage", None)

    return ChatResponse(
        content=message.content or "",
        tool_calls=tool_calls,
        model=model_id,
        provider=self.name,
        prompt_tokens=getattr(usage, "prompt_tokens", None) or 0,
        completion_tokens=getattr(usage, "completion_tokens", None) or 0,
        latency_ms=latency,
    )

list_models()

List available Groq models.

Source code in wintermute/ai/providers/groq_provider.py
57
58
59
60
61
62
63
64
65
66
67
68
69
def list_models(self) -> list[ModelInfo]:
    """Return the fixed catalogue of Groq models this provider advertises.

    A fresh list is built on every call, so callers may mutate it freely.
    """
    # (model id, second ModelInfo field). The second value looks like the model
    # family; confirm against ModelInfo's definition.
    catalogue = (
        ("groq/llama-3.3-70b-versatile", "llama-3.3"),
        ("groq/llama-3.1-70b-versatile", "llama-3.1"),
        ("groq/llama-3.1-8b-instant", "llama-3.1"),
    )
    # Every entry shares the same context size and capability flags.
    return [
        ModelInfo(model_id, tag, 128_000, True, True, True)
        for model_id, tag in catalogue
    ]

register(api_key=None, *, as_name='groq')

Register the GroqProvider with Wintermute AI LLM registry.

Source code in wintermute/ai/providers/groq_provider.py
149
150
151
152
153
154
def register(api_key: Optional[str] = None, *, as_name: str = "groq") -> None:
    """Create a GroqProvider and add it to the Wintermute AI LLM registry.

    Args:
        api_key: Optional Groq API key handed to the new provider.
        as_name: Stored as the provider's ``_name``, which its ``name``
            property reports.
    """
    # Imported inside the function, presumably to avoid a circular import with
    # the provider package. Confirm before hoisting to module level.
    from ..provider import llms

    llms.register(GroqProvider(api_key=api_key, _name=as_name))