from flask import Flask, request, jsonify
from openai import OpenAI
import os
import json
import logging
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler()]
)
logger = logging.getLogger('databricks-proxy')
app = Flask(__name__)
# Databricks token and endpoint
DATABRICKS_TOKEN = os.environ.get("DATABRICKS_TOKEN", "****")
DATABRICKS_BASE_URL = os.environ.get(
"DATABRICKS_BASE_URL",
"https://****.net/serving-endpoints"
)
DATABRICKS_MODEL = os.environ.get("DATABRICKS_MODEL", "databricks-claude-3-7-sonnet")
# Initialize the Databricks client
client = OpenAI(
api_key=DATABRICKS_TOKEN,
base_url=DATABRICKS_BASE_URL
)
logger.info(f"Proxy configured for model: {DATABRICKS_MODEL}")
logger.info(f"Using base URL: {DATABRICKS_BASE_URL}")
@app.route('/v1/chat/completions', methods=['POST'])
def chat_completions():
"""
Proxy endpoint for chat completions that forwards to Databricks Claude
"""
try:
        # Get the request data (tolerate a missing or non-JSON body)
        data = request.get_json(silent=True) or {}
        # Extract request parameters (use defaults if not provided)
        messages = data.get('messages', [])
        if not messages:
            return jsonify({"error": "'messages' is required"}), 400
        stream = data.get('stream', False)
        # Pass through common optional sampling parameters when the client supplies them
        extra_params = {k: data[k] for k in ('temperature', 'top_p', 'max_tokens', 'stop') if k in data}
# Log the incoming request
logger.info(f"Received chat completion request, streaming={stream}")
        # Forward the request to Databricks
        response = client.chat.completions.create(
            model=DATABRICKS_MODEL,
            messages=messages,
            stream=stream,
            **extra_params
        )
        # If streaming is enabled, relay the chunks back as Server-Sent Events
if stream:
def generate():
for chunk in response:
# Convert the chunk to a dictionary
chunk_dict = {
"id": chunk.id,
"object": "chat.completion.chunk",
"created": chunk.created,
"model": DATABRICKS_MODEL,
"choices": [
{
"index": c.index,
"delta": {
"role": c.delta.role if c.delta.role else None,
"content": c.delta.content if c.delta.content else None
},
"finish_reason": c.finish_reason
} for c in chunk.choices
]
}
# Use json.dumps instead of jsonify to avoid app context issues
yield f"data: {json.dumps(chunk_dict)}\n\n"
yield "data: [DONE]\n\n"
return app.response_class(generate(), mimetype='text/event-stream')
else:
# For non-streaming, return the complete response
return jsonify({
"id": response.id,
"object": "chat.completion",
"created": response.created,
"model": DATABRICKS_MODEL,
"choices": [
{
"index": c.index,
"message": {
"role": c.message.role,
"content": c.message.content
},
"finish_reason": c.finish_reason
} for c in response.choices
],
"usage": {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
"total_tokens": response.usage.total_tokens
}
})
except Exception as e:
logger.error(f"Error in chat completion: {str(e)}", exc_info=True)
return jsonify({"error": str(e)}), 500
@app.route('/v1/models', methods=['GET'])
def list_models():
"""
Provide a fake models list endpoint for compatibility
"""
logger.info("Received models list request")
return jsonify({
"object": "list",
"data": [
{
"id": DATABRICKS_MODEL,
"object": "model",
"created": 1677610602,
"owned_by": "databricks"
}
]
})
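# Example check (assumes the default port):
#
#   curl http://localhost:8000/v1/models
#
# returns the single configured model, which is enough for clients and chat UIs
# that list /v1/models before sending requests.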
@app.route('/health', methods=['GET'])
def health_check():
"""
Health check endpoint for monitoring
"""
return jsonify({"status": "healthy"})
if __name__ == '__main__':
# Set default port
port = int(os.environ.get("PORT", 8000))
# Run the Flask app
logger.info(f"Starting proxy server at http://localhost:{port}")
logger.info(f"Forwarding requests to {DATABRICKS_BASE_URL}")
    # Warn when the placeholder token is being used
    if os.environ.get("DATABRICKS_TOKEN") is None:
        logger.warning("DATABRICKS_TOKEN environment variable not set; using the placeholder default")
        logger.warning("Set DATABRICKS_TOKEN in the environment rather than hardcoding credentials")
    # Note: debug=True enables Flask's reloader and debugger; disable it outside local development
    app.run(host='0.0.0.0', port=port, debug=True)
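# Sketch of a client pointed at this proxy (assumes the proxy runs on localhost:8000;
# the api_key value only satisfies the SDK, since the proxy does not check it, and the
# model field is likewise ignored because the proxy always forwards DATABRICKS_MODEL):
#
#   from openai import OpenAI
#
#   client = OpenAI(api_key="unused", base_url="http://localhost:8000/v1")
#   reply = client.chat.completions.create(
#       model="databricks-claude-3-7-sonnet",
#       messages=[{"role": "user", "content": "Hello"}],
#   )
#   print(reply.choices[0].message.content)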