from flask import Flask, request, jsonify
from openai import OpenAI
import os
import json
import logging

# Logging: INFO and above go to a stream handler, one timestamped line per record.
_log_handler = logging.StreamHandler()
logging.basicConfig(
    handlers=[_log_handler],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Named logger shared by every part of the proxy.
logger = logging.getLogger('databricks-proxy')

app = Flask(__name__)

# --- Upstream (Databricks) settings -----------------------------------------
# Each value can be overridden from the environment. The fallbacks below are
# placeholders ("****"), not usable credentials.
# NOTE(review): a real token should only ever be supplied via the environment.
DATABRICKS_TOKEN = os.getenv("DATABRICKS_TOKEN", "****")
DATABRICKS_BASE_URL = os.getenv("DATABRICKS_BASE_URL", "https://****.net/serving-endpoints")
DATABRICKS_MODEL = os.getenv("DATABRICKS_MODEL", "databricks-claude-3-7-sonnet")

# One shared OpenAI-compatible client, aimed at the Databricks
# serving-endpoints URL and authenticated with the Databricks token.
client = OpenAI(base_url=DATABRICKS_BASE_URL, api_key=DATABRICKS_TOKEN)

# Announce the effective upstream configuration once at import time.
logger.info(f"Proxy configured for model: {DATABRICKS_MODEL}")
logger.info(f"Using base URL: {DATABRICKS_BASE_URL}")

# Optional OpenAI chat-completion parameters relayed to Databricks when the
# caller sets them (None values are skipped).
# NOTE(review): deliberately a conservative subset; extend if the serving
# endpoint accepts more.
_PASSTHROUGH_PARAMS = ("max_tokens", "temperature", "top_p", "stop")


def _serialize_chunk(chunk):
    """Convert one streamed completion chunk into an OpenAI-style dict."""
    return {
        "id": chunk.id,
        "object": "chat.completion.chunk",
        "created": chunk.created,
        "model": DATABRICKS_MODEL,
        "choices": [
            {
                "index": c.index,
                "delta": {
                    # Falsy values (None / "") are normalized to None.
                    "role": c.delta.role or None,
                    "content": c.delta.content or None,
                },
                "finish_reason": c.finish_reason,
            }
            for c in chunk.choices
        ],
    }


def _serialize_completion(response):
    """Convert a non-streamed completion into an OpenAI-style dict."""
    usage = response.usage
    return {
        "id": response.id,
        "object": "chat.completion",
        "created": response.created,
        "model": DATABRICKS_MODEL,
        "choices": [
            {
                "index": c.index,
                "message": {
                    "role": c.message.role,
                    "content": c.message.content,
                },
                "finish_reason": c.finish_reason,
            }
            for c in response.choices
        ],
        # The upstream may omit usage; report null instead of crashing.
        "usage": (
            {
                "prompt_tokens": usage.prompt_tokens,
                "completion_tokens": usage.completion_tokens,
                "total_tokens": usage.total_tokens,
            }
            if usage is not None
            else None
        ),
    }


def _stream_events(stream):
    """Yield Server-Sent-Event lines for each upstream chunk, then ``[DONE]``.

    The HTTP status has already been sent by the time this runs, so an error
    raised mid-stream is logged and reported as a final ``error`` event
    rather than a status code. The upstream stream is closed on every exit
    path, including when the client disconnects early (GeneratorExit).
    """
    try:
        for chunk in stream:
            # json.dumps rather than jsonify: this generator runs after the
            # view has returned, outside the app/request context.
            yield f"data: {json.dumps(_serialize_chunk(chunk))}\n\n"
        yield "data: [DONE]\n\n"
    except Exception as exc:
        logger.exception("Error while streaming chat completion: %s", exc)
        yield f"data: {json.dumps({'error': {'message': str(exc)}})}\n\n"
    finally:
        stream.close()


@app.route('/v1/chat/completions', methods=['POST'])
def chat_completions():
    """Proxy an OpenAI-style chat completion request to Databricks.

    Reads ``messages`` (required, non-empty list), ``stream`` (optional) and
    the parameters in ``_PASSTHROUGH_PARAMS`` from the JSON body. The model
    is always ``DATABRICKS_MODEL``, whatever the caller requested.

    Returns:
        An ``text/event-stream`` response when ``stream`` is truthy,
        otherwise a JSON chat completion. 400 with ``{"error": ...}`` for a
        malformed body; 500 with ``{"error": ...}`` when the upstream call
        fails.
    """
    # silent=True: a non-JSON body yields None (-> 400) instead of raising.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    messages = data.get('messages')
    if not isinstance(messages, list) or not messages:
        return jsonify({"error": "'messages' must be a non-empty list"}), 400

    stream = data.get('stream', False)
    extra_params = {
        name: data[name]
        for name in _PASSTHROUGH_PARAMS
        if data.get(name) is not None
    }

    logger.info(f"Received chat completion request, streaming={stream}")

    try:
        response = client.chat.completions.create(
            model=DATABRICKS_MODEL,
            messages=messages,
            stream=stream,
            **extra_params,
        )
        if stream:
            return app.response_class(
                _stream_events(response), mimetype='text/event-stream'
            )
        return jsonify(_serialize_completion(response))
    except Exception as e:
        logger.exception(f"Error in chat completion: {str(e)}")
        return jsonify({"error": str(e)}), 500

@app.route('/v1/models', methods=['GET'])
def list_models():
    """Return a static, OpenAI-style model list for client compatibility.

    The list has exactly one entry, the configured ``DATABRICKS_MODEL``. It
    is not fetched from Databricks.
    """
    logger.info("Received models list request")
    return jsonify({
        "object": "list",
        "data": [
            {
                "id": DATABRICKS_MODEL,
                "object": "model",
                # Fixed placeholder timestamp; not derived from Databricks.
                "created": 1677610602,
                "owned_by": "databricks"
            }
        ]
    })

@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: always reports ``{"status": "healthy"}``.

    Does not contact Databricks, so it does not verify upstream reachability.
    """
    return jsonify({"status": "healthy"})

if __name__ == '__main__':
    # Listening port; override with the PORT environment variable.
    port = int(os.environ.get("PORT", 8000))

    # The Werkzeug interactive debugger allows arbitrary code execution, so it
    # must not be enabled by default while bound to all interfaces (0.0.0.0).
    # Opt in explicitly for local development with FLASK_DEBUG=1.
    debug = os.environ.get("FLASK_DEBUG", "").strip().lower() in {"1", "true", "yes", "on"}

    logger.info(f"Starting proxy server at http://localhost:{port}")
    logger.info(f"Forwarding requests to {DATABRICKS_BASE_URL}")

    # Warn when the token falls back to the hard-coded placeholder.
    if os.environ.get("DATABRICKS_TOKEN") is None:
        logger.warning("DATABRICKS_TOKEN environment variable not set, using default value")
        logger.warning("For better security, set the DATABRICKS_TOKEN environment variable")

    app.run(host='0.0.0.0', port=port, debug=debug)