"""
Input Processing Lambda - downloads and validates device JSON files from S3, records batch metadata via the RDS Data API, and splits devices into batches stored in S3
"""
import json
import boto3
import logging
import os
from datetime import datetime
from typing import Dict, List, Any

# Root logger at INFO so the logger.info progress messages in this module are emitted.
logger = logging.getLogger()
logger.setLevel(logging.INFO)


# AWS service clients, created once when the module is imported and shared by
# every function below (and so by every invocation served by this process).
s3_client, cloudwatch, rds_client = (
    boto3.client('s3'),
    boto3.client('cloudwatch'),
    boto3.client('rds-data'),
)

def execute_sql(sql, parameters=None):
    """
    Run a single SQL statement through the RDS Data API.

    Connection details come from the DB_CLUSTER_ARN, DB_SECRET_ARN and DB_NAME
    environment variables.

    Args:
        sql: Statement text; may reference named placeholders such as ``:batch_id``.
        parameters: Optional list of Data API parameter dicts. Left out of the
            request entirely when None or empty.

    Returns:
        The raw ``execute_statement`` response.

    Raises:
        Any error from building or sending the request (including KeyError for a
        missing environment variable) is logged together with the SQL text and
        re-raised.
    """
    try:
        request_kwargs = dict(
            resourceArn=os.environ['DB_CLUSTER_ARN'],
            secretArn=os.environ['DB_SECRET_ARN'],
            database=os.environ['DB_NAME'],
            sql=sql,
        )
        if parameters:
            request_kwargs['parameters'] = parameters
        return rds_client.execute_statement(**request_kwargs)
    except Exception as e:
        logger.error(f"Failed to execute SQL: {str(e)}")
        logger.error(f"SQL: {sql}")
        raise

def store_batch_info(batch_name: str, operation: str, total_devices: int, s3_url: str):
    """
    Upsert the ``batches`` row for ``batch_name`` with status PROCESSING.

    A new row gets ``created_at``; an existing row has its operation, device
    count and source URL overwritten, its status reset to PROCESSING and
    ``updated_at`` refreshed.

    This is best effort: any failure is logged and swallowed, and nothing is
    returned.
    """
    try:
        sql = """
            INSERT INTO batches (batch_id, operation, total_devices, s3_url, status, created_at)
            VALUES (:batch_id, :operation, :total_devices, :s3_url, 'PROCESSING', CURRENT_TIMESTAMP)
            ON CONFLICT (batch_id) DO UPDATE SET
                operation = EXCLUDED.operation,
                total_devices = EXCLUDED.total_devices,
                s3_url = EXCLUDED.s3_url,
                status = 'PROCESSING',
                updated_at = CURRENT_TIMESTAMP
        """
        # (placeholder name, typed Data API value) pairs, in statement order.
        bindings = (
            ('batch_id', {'stringValue': batch_name}),
            ('operation', {'stringValue': operation}),
            ('total_devices', {'longValue': total_devices}),
            ('s3_url', {'stringValue': s3_url}),
        )
        execute_sql(sql, [{'name': name, 'value': value} for name, value in bindings])
        logger.info(f"Stored batch info for {batch_name}")
    except Exception as e:
        logger.error(f"Failed to store batch info: {str(e)}")

def emit_metric(metric_name: str, value: float, unit: str = 'Count', dimensions: Dict[str, str] = None):
    """
    Publish one data point to the ``IoTWireless/BulkManagement`` CloudWatch namespace.

    Args:
        metric_name: CloudWatch metric name.
        value: Data point value.
        unit: CloudWatch unit name (default ``'Count'``).
        dimensions: Optional mapping of dimension name to value; may be None.
            Values are converted with ``str()`` because CloudWatch dimension
            values must be strings, and a single non-string value would
            otherwise make the whole call fail.

    Failures are logged and swallowed; metrics are best effort.
    """
    from datetime import timezone

    try:
        metric_data = {
            'MetricName': metric_name,
            'Value': value,
            'Unit': unit,
            # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated since Python 3.12.
            'Timestamp': datetime.now(timezone.utc),
        }

        if dimensions:
            metric_data['Dimensions'] = [
                {'Name': str(name), 'Value': str(dim_value)}
                for name, dim_value in dimensions.items()
            ]

        cloudwatch.put_metric_data(
            Namespace='IoTWireless/BulkManagement',
            MetricData=[metric_data]
        )

    except Exception as e:
        logger.error(f"Failed to emit metric {metric_name}: {str(e)}")

def download_s3_file(bucket: str, key: str) -> str:
    """
    Download an S3 object and return its contents as text.

    The body is decoded as UTF-8. A leading UTF-8 byte-order mark is stripped,
    because ``json.loads`` rejects text that starts with a BOM.

    Args:
        bucket: Source bucket name.
        key: Object key.

    Returns:
        The decoded file contents.

    Raises:
        Any S3 or decoding error, after logging it.
    """
    try:
        response = s3_client.get_object(Bucket=bucket, Key=key)
        raw_body = response['Body'].read()
        # 'utf-8-sig' behaves like 'utf-8' but also drops a leading BOM, which
        # editors and spreadsheet exports often add.
        content = raw_body.decode('utf-8-sig')
        # Report the real byte size, not the decoded character count.
        logger.info(f"Downloaded file s3://{bucket}/{key} ({len(raw_body)} bytes)")
        return content
    except Exception as e:
        logger.error(f"Failed to download s3://{bucket}/{key}: {str(e)}")
        raise

def validate_device_data(device: Dict[str, Any], operation: str) -> Dict[str, Any]:
    """
    Validate a single device entry for the given operation.

    Rules:
      * create: requires ``smsn``, ``deviceProfileId`` (or ``device_profile_id``)
        and ``uplinkDestinationName`` (or ``uplink_destination_name``).
      * update: requires at least one of ``smsn`` or ``awsWirelessDeviceId``
        (or ``aws_wireless_device_id``).
      * create/update: an optional ``deviceName`` (or ``device_name``) must be a
        string of at most 255 characters.
      * any operation: a non-empty ``smsn`` must be a string of exactly 64 characters.

    Non-string ``smsn`` / ``deviceName`` values are reported as errors for this
    device. Previously ``len()`` raised TypeError, which aborted validation of
    the whole file.

    Args:
        device: One device object from the uploaded JSON.
        operation: ``'create'`` or ``'update'``. Other values only get the SMSN check.

    Returns:
        dict with ``'valid'`` (bool), ``'errors'`` (list of str) and ``'device'``
        (the input dict, unchanged).
    """
    errors: List[str] = []
    device_name = device.get('deviceName') or device.get('device_name', '')

    if operation == 'create':
        # Required fields for create operations
        if not device.get('smsn'):
            errors.append("Missing required field: smsn")

        # deviceProfileId is required (accept both camelCase and snake_case)
        if not device.get('deviceProfileId') and not device.get('device_profile_id'):
            errors.append("Missing required field: deviceProfileId (or device_profile_id)")

        # uplinkDestinationName is required for create operations
        if not device.get('uplinkDestinationName') and not device.get('uplink_destination_name'):
            errors.append("Missing required field: uplinkDestinationName (or uplink_destination_name)")

    elif operation == 'update':
        # For update operations, require at least one identifier
        has_smsn = bool(device.get('smsn'))
        has_aws_id = bool(device.get('awsWirelessDeviceId') or device.get('aws_wireless_device_id'))

        if not has_smsn and not has_aws_id:
            errors.append("Missing required identifier: at least one of 'smsn' or 'awsWirelessDeviceId' must be provided")

    # deviceName is optional for both operations, but must be a short string when present.
    if operation in ('create', 'update') and device_name:
        if not isinstance(device_name, str):
            errors.append(f"Invalid device name (must be a string): {device_name!r}")
        elif len(device_name) > 255:
            errors.append(f"Device name too long: {device_name}")

    # Validate SMSN format if provided
    smsn = device.get('smsn', '')
    if smsn:
        if not isinstance(smsn, str):
            errors.append(f"Invalid SMSN format (must be a string): {smsn!r}")
        elif len(smsn) != 64:
            errors.append(f"Invalid SMSN format (must be exactly 64 characters): {smsn}")

    return {
        'valid': len(errors) == 0,
        'errors': errors,
        'device': device
    }

def validate_json_structure(data: Any, operation: str) -> Dict[str, Any]:
    """
    Check the overall shape of an uploaded file and validate every device in it.

    Two layouts are accepted: a bare JSON array of devices, or an object with a
    ``devices`` array. Each array element must be an object and is then checked
    with ``validate_device_data``.

    Returns:
        dict with keys:
          * ``valid``: False on a structural problem, an empty array, zero valid
            devices, or an unexpected exception.
          * ``errors``: file-level messages.
          * ``valid_devices``: devices that passed.
          * ``invalid_devices``: ``{'index', 'device', 'errors'}`` entries.
          * ``total_devices``: array length.
    """
    outcome: Dict[str, Any] = {
        'valid': True,
        'errors': [],
        'valid_devices': [],
        'invalid_devices': [],
        'total_devices': 0
    }

    def reject(message: str) -> Dict[str, Any]:
        # Mark the whole file invalid and record why.
        outcome['valid'] = False
        outcome['errors'].append(message)
        return outcome

    try:
        if isinstance(data, list):
            # Direct array format: [{"smsn": "...", ...}, ...]
            entries = data
        elif isinstance(data, dict) and 'devices' in data:
            # Wrapped object format: {"operation": "create", "devices": [...]}
            entries = data['devices']
            if not isinstance(entries, list):
                return reject("'devices' property must contain an array")
        else:
            return reject("JSON must contain either an array of devices or an object with a 'devices' array property")

        if not entries:
            return reject("Devices array is empty")

        outcome['total_devices'] = len(entries)

        for position, entry in enumerate(entries):
            if not isinstance(entry, dict):
                outcome['invalid_devices'].append({
                    'index': position,
                    'device': entry,
                    'errors': ['Device must be a JSON object']
                })
                continue

            checked = validate_device_data(entry, operation)
            if checked['valid']:
                outcome['valid_devices'].append(checked['device'])
            else:
                outcome['invalid_devices'].append({
                    'index': position,
                    'device': checked['device'],
                    'errors': checked['errors']
                })

        # Partial success is fine; only a file with zero usable devices fails.
        if not outcome['valid_devices']:
            reject("No valid devices found in the file")

        logger.info(f"Validation complete: {len(outcome['valid_devices'])} valid, {len(outcome['invalid_devices'])} invalid")

    except Exception as e:
        outcome['valid'] = False
        outcome['errors'].append(f"Validation error: {str(e)}")

    return outcome

def store_batches_in_s3(batches: List[Dict], batch_name: str, bucket_name: str) -> List[Dict]:
    """
    Write each batch to S3 and return lightweight references for Step Functions.

    Full batch payloads are kept out of the state machine input to avoid the
    Step Functions 256 KB payload limit. Each batch is written to
    ``batches/<batch_name>/batch-<n>.json`` (n is 1-based). More than 50 batches
    are uploaded concurrently with up to 20 worker threads; fewer are uploaded
    sequentially.

    Args:
        batches: Batch dicts as produced by ``create_batches``.
        batch_name: Prefix used in the S3 keys.
        bucket_name: Destination bucket.

    Returns:
        One reference dict per batch, in the same order as ``batches``.

    Raises:
        The first upload error, after logging it. In the concurrent path,
        uploads that have not started yet are cancelled first so the failure
        surfaces promptly.
    """
    import concurrent.futures

    def upload_single_batch(index: int, batch: Dict) -> Dict:
        # Upload one batch and build its reference.
        s3_key = f"batches/{batch_name}/batch-{index + 1}.json"

        s3_client.put_object(
            Bucket=bucket_name,
            Key=s3_key,
            Body=json.dumps(batch),
            ContentType='application/json'
        )

        return {
            'batchName': batch['batchName'],
            'operation': batch['operation'],
            'tpsPerLambda': batch['tpsPerLambda'],
            'deviceCount': len(batch['devices']),
            'startIndex': batch['startIndex'],
            'endIndex': batch['endIndex'],
            's3Bucket': bucket_name,
            's3Key': s3_key
        }

    try:
        if len(batches) > 50:
            logger.info(f"Using concurrent uploads for {len(batches)} batches")

            # Cap concurrency to avoid overwhelming S3.
            max_workers = min(20, len(batches))
            uploaded_refs: Dict[int, Dict] = {}

            with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                future_to_index = {
                    executor.submit(upload_single_batch, i, batch): i
                    for i, batch in enumerate(batches)
                }

                # as_completed only yields finished futures, so result() never blocks here.
                for future in concurrent.futures.as_completed(future_to_index):
                    index = future_to_index[future]
                    try:
                        uploaded_refs[index] = future.result()
                    except Exception as e:
                        logger.error(f"Failed to upload batch {index}: {str(e)}")
                        # Cancel uploads that have not started; otherwise the executor's
                        # shutdown would wait for all of them before the error propagates.
                        for pending in future_to_index:
                            pending.cancel()
                        raise

            # Restore the original batch order.
            batch_references = [uploaded_refs[i] for i in range(len(batches))]
        else:
            batch_references = [upload_single_batch(i, batch) for i, batch in enumerate(batches)]

        logger.info(f"Successfully stored {len(batches)} batches in S3 bucket {bucket_name}")
        return batch_references

    except Exception as e:
        logger.error(f"Failed to store batches in S3: {str(e)}")
        raise

def create_batches(devices: List[Dict], operation: str, batch_name: str) -> List[Dict]:
    """
    Split devices into contiguous batches for parallel processing.

    Configuration comes from environment variables: CREATE_* when ``operation``
    is ``'create'``, UPDATE_* otherwise. The variables are ``*_BATCH_SIZE``
    (default 100), ``*_MAX_CONCURRENCY`` (default 5) and ``*_TPS_PER_LAMBDA``
    (default 2).

    Sizing:
      * No batch may exceed ``min(batch_size, tps_per_lambda * 300)`` devices.
        This assumes the processing Lambda has a 300 s timeout, per the
        original comment; confirm against the deployment config.
      * At most ``max_concurrency`` devices: one batch per device.
      * Otherwise the batch count is ``max_concurrency``, or more if the size
        cap requires it. Devices are spread evenly (sizes differ by at most
        one), so every concurrency slot is actually used. Fixed-size
        ``ceil(N/C)`` slicing could yield far fewer batches, e.g. 51 instead
        of 100 for N=101, C=100.

    Returns:
        List of batch dicts with ``batchName`` (``<batch_name>-<n>``),
        ``devices``, ``startIndex``, ``endIndex`` (exclusive), ``operation`` and
        ``tpsPerLambda``.
    """
    # Get operation-specific configuration
    if operation == 'create':
        batch_size = int(os.environ.get('CREATE_BATCH_SIZE', '100'))
        max_concurrency = int(os.environ.get('CREATE_MAX_CONCURRENCY', '5'))
        tps_per_lambda = int(os.environ.get('CREATE_TPS_PER_LAMBDA', '2'))
    else:  # update
        batch_size = int(os.environ.get('UPDATE_BATCH_SIZE', '100'))
        max_concurrency = int(os.environ.get('UPDATE_MAX_CONCURRENCY', '5'))
        tps_per_lambda = int(os.environ.get('UPDATE_TPS_PER_LAMBDA', '2'))

    total_devices = len(devices)

    # Clamp misconfigured (zero/negative) values. Otherwise they cause a
    # ZeroDivisionError or a zero-step range below.
    max_concurrency = max(1, max_concurrency)

    # Each device takes 1/tps_per_lambda seconds; keep a batch within 300 seconds.
    max_safe_batch_size = tps_per_lambda * 300
    safe_batch_size = max(1, min(batch_size, max_safe_batch_size))

    # Enough batches to fill the concurrency slots (never more than one per device),
    # and enough that none exceeds the safe size. Both terms are 0 for no devices.
    min_batches_for_limit = (total_devices + safe_batch_size - 1) // safe_batch_size
    num_batches = max(min(total_devices, max_concurrency), min_batches_for_limit)

    # Even split: the first `remainder` batches get one extra device.
    base_size, remainder = divmod(total_devices, num_batches) if num_batches else (0, 0)
    largest_batch_size = base_size + (1 if remainder else 0)

    logger.info(f"Batch optimization: {total_devices} devices, max_concurrency={max_concurrency}")
    logger.info(f"Safe batch size limit: {safe_batch_size} devices (timeout constraint)")
    logger.info(f"Optimal batch size: {largest_batch_size} devices")
    logger.info(f"Will create {num_batches} batches")

    batches = []
    start = 0
    for index in range(num_batches):
        end = start + base_size + (1 if index < remainder else 0)
        batches.append({
            'batchName': f"{batch_name}-{index + 1}",
            'devices': devices[start:end],
            'startIndex': start,
            'endIndex': end,
            'operation': operation,
            'tpsPerLambda': tps_per_lambda
        })
        start = end

    # Log processing strategy
    if len(batches) <= max_concurrency:
        logger.info(f"All {len(batches)} batches will run concurrently")
    else:
        waves = (len(batches) + max_concurrency - 1) // max_concurrency
        logger.info(f"Created {len(batches)} batches, will process in {waves} waves of {max_concurrency} concurrent batches")

    logger.info(f"Batch creation complete: {len(batches)} batches for {total_devices} devices (operation: {operation})")
    return batches

def determine_operation(devices: List[Dict]) -> str:
    """
    Guess whether a device list describes a ``'create'`` or an ``'update'`` job.

    Only the first five entries are examined. An entry votes "update" when it
    carries a wireless device id (``awsWirelessDeviceId`` or
    ``aws_wireless_device_id``, the same aliases accepted during validation).
    It votes "create" when it has both a device name and a device profile id.
    Non-object entries are skipped; validation reports them later.

    Returns:
        ``'update'`` when update votes strictly outnumber create votes,
        otherwise ``'create'`` (including for an empty list).
    """
    update_votes = 0
    create_votes = 0

    for device in devices[:5]:
        if not isinstance(device, dict):
            continue
        if device.get('awsWirelessDeviceId') or device.get('aws_wireless_device_id'):
            update_votes += 1
        # Check both camelCase and snake_case field names
        if (device.get('deviceName') or device.get('device_name')) and \
           (device.get('deviceProfileId') or device.get('device_profile_id')):
            create_votes += 1

    return 'update' if update_votes > create_votes else 'create'

def _extract_s3_location(event: Dict[str, Any]) -> tuple:
    """Return (bucket, key) from an EventBridge S3 event, an S3 notification, or a manual payload."""
    if 'detail' in event and 'bucket' in event['detail']:
        # EventBridge S3 event
        detail = event['detail']
        return detail['bucket']['name'], detail['object']['key']
    if 'Records' in event:
        # Direct S3 event notification (fallback). S3 URL-encodes object keys in
        # these notifications (e.g. a space arrives as '+'), so decode before use.
        from urllib.parse import unquote_plus
        s3_info = event['Records'][0]['s3']
        return s3_info['bucket']['name'], unquote_plus(s3_info['object']['key'])
    # Manual invocation for testing: {"bucket": ..., "key": ...}
    return event.get('bucket'), event.get('key')


def _resolve_operation(json_data: Any) -> tuple:
    """
    Determine the requested operation for a parsed upload.

    Returns ``(operation, error)``. ``error`` is None for ``'create'``/``'update'``.
    Otherwise it is a message describing the unsupported value, and
    ``operation`` is that value's ``str()`` form so it can still be reported.
    """
    operation = 'create'  # default
    if isinstance(json_data, dict) and 'operation' in json_data:
        # Operation specified in wrapped format
        operation = json_data['operation']
    elif isinstance(json_data, list) and json_data:
        # Detect operation from device data
        operation = determine_operation(json_data)
    elif isinstance(json_data, dict) and isinstance(json_data.get('devices'), list) and json_data['devices']:
        # Detect operation from devices in wrapped format. A non-list 'devices'
        # is left to validate_json_structure to report.
        operation = determine_operation(json_data['devices'])

    if operation not in ('create', 'update'):
        return str(operation), f"Unsupported operation {operation!r}: must be 'create' or 'update'"
    return operation, None


def _derive_batch_name(json_data: Any, object_key: str) -> str:
    """Use ``batchName`` from the JSON object if present, else derive one from the S3 key."""
    if isinstance(json_data, dict) and 'batchName' in json_data:
        batch_name = json_data['batchName']
        logger.info(f"Using batch name from JSON: {batch_name}")
        return batch_name
    # Fallback: strip only a trailing '.json' (not every occurrence), then flatten '/'.
    stem = object_key[:-len('.json')] if object_key.endswith('.json') else object_key
    batch_name = stem.replace('/', '_')
    logger.info(f"Using batch name from S3 key: {batch_name}")
    return batch_name


def _mark_batch_failed(batch_name: str) -> None:
    """Best effort: set the batch row's status to FAILED after a validation failure."""
    try:
        execute_sql(
            """
            UPDATE batches
            SET status = 'FAILED', updated_at = CURRENT_TIMESTAMP
            WHERE batch_id = :batch_id
            """,
            [{'name': 'batch_id', 'value': {'stringValue': batch_name}}]
        )
    except Exception as e:
        logger.error(f"Failed to mark batch {batch_name} as FAILED: {str(e)}")


def lambda_handler(event, context):
    """
    Main handler for input processing and validation.

    Reads the uploaded device file from S3, validates it, records batch metadata
    via the RDS Data API, and writes per-batch payloads to REPORT_BUCKET so the
    Step Functions payload only carries lightweight references.

    Returns:
        dict with ``statusCode``:
          * 200 on success.
          * 200 on validation failure, with ``validationFailed: True`` so the
            state machine continues to report-notify.
          * 400 when the file is not valid JSON.
          * 500 for any other error.
    """
    logger.info("Input processing started")
    logger.info(f"Event: {json.dumps(event)}")

    try:
        bucket_name, object_key = _extract_s3_location(event)
        if not bucket_name or not object_key:
            raise ValueError("Could not extract S3 bucket and key from event")

        logger.info(f"Processing file: s3://{bucket_name}/{object_key}")

        # Download and parse JSON file
        file_content = download_s3_file(bucket_name, object_key)

        try:
            json_data = json.loads(file_content)
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON format: {str(e)}")
            emit_metric('ValidationFailure', 1, 'Count')
            return {
                'statusCode': 400,
                'error': f"Invalid JSON format: {str(e)}",
                'bucket': bucket_name,
                'key': object_key
            }

        operation, operation_error = _resolve_operation(json_data)
        logger.info(f"Detected operation: {operation}")

        # Validate JSON structure and content
        validation_result = validate_json_structure(json_data, operation)
        if operation_error:
            # An unsupported operation invalidates the whole file, whatever the per-device checks said.
            validation_result['valid'] = False
            validation_result['errors'].insert(0, operation_error)

        # Batch name is needed for both success and failure cases
        batch_name = _derive_batch_name(json_data, object_key)
        s3_url = f"s3://{bucket_name}/{object_key}"
        report_bucket = os.environ.get('REPORT_BUCKET')

        if not validation_result['valid']:
            logger.error(f"Validation failed: {validation_result['errors']}")
            emit_metric('ValidationFailure', 1, 'Count', {'Operation': operation})

            # Record the batch, then flag it FAILED: store_batch_info always writes PROCESSING.
            store_batch_info(batch_name, operation, 0, s3_url)
            _mark_batch_failed(batch_name)

            # Store validation errors for report-notify to process
            validation_error_key = f"batches/{batch_name}/validation-errors.json"

            validation_error_data = {
                'batch_name': batch_name,
                'operation': operation,
                'status': 'VALIDATION_FAILED',
                'timestamp': datetime.utcnow().isoformat(),
                'source_file': s3_url,
                'validation_errors': validation_result['errors'],
                'total_devices': validation_result['total_devices'],
                'valid_devices': 0,
                'invalid_devices_count': len(validation_result['invalid_devices']),
                'invalid_devices': validation_result['invalid_devices']
            }

            try:
                s3_client.put_object(
                    Bucket=report_bucket,
                    Key=validation_error_key,
                    Body=json.dumps(validation_error_data, indent=2),
                    ContentType='application/json'
                )
                logger.info(f"Validation errors stored at s3://{report_bucket}/{validation_error_key}")
            except Exception as e:
                logger.error(f"Failed to store validation errors: {str(e)}")

            # Return structure that triggers report-notify Lambda with validation failure info
            return {
                'statusCode': 200,  # Return 200 so Step Functions continues to report-notify
                'validationFailed': True,
                'operation': operation,
                'batchName': batch_name,
                'totalDevices': 0,
                'batchCount': 0,
                'batches': [],
                's3Url': s3_url,
                'validationErrorsS3Key': validation_error_key,
                'validationSummary': {
                    'total_devices': validation_result['total_devices'],
                    'valid_devices': 0,
                    'invalid_devices': len(validation_result['invalid_devices']),
                    'validation_errors': validation_result['errors']
                }
            }

        valid_devices = validation_result['valid_devices']

        # Store batch information in database
        store_batch_info(batch_name, operation, len(valid_devices), s3_url)

        # Create batches for parallel processing
        batches = create_batches(valid_devices, operation, batch_name)

        # Store batches in S3 to avoid Step Functions payload size limits (256 KB).
        # Use report bucket to avoid triggering EventBridge on batch file creation.
        batch_references = store_batches_in_s3(batches, batch_name, report_bucket)

        # Get operation-specific configuration for response
        if operation == 'create':
            max_concurrency = int(os.environ.get('CREATE_MAX_CONCURRENCY', '5'))
            tps_per_lambda = int(os.environ.get('CREATE_TPS_PER_LAMBDA', '2'))
        else:  # update
            max_concurrency = int(os.environ.get('UPDATE_MAX_CONCURRENCY', '5'))
            tps_per_lambda = int(os.environ.get('UPDATE_TPS_PER_LAMBDA', '2'))

        # Emit metrics
        emit_metric('ValidationSuccess', 1, 'Count', {'Operation': operation})
        emit_metric('DevicesValidated', len(valid_devices), 'Count', {'Operation': operation})
        emit_metric('BatchProcessingStarted', len(batches), 'Count', {'Operation': operation})

        logger.info(f"Input processing completed successfully")
        logger.info(f"Valid devices: {len(valid_devices)}")
        logger.info(f"Batches created: {len(batches)}")

        return {
            'statusCode': 200,
            'operation': operation,
            'batchName': batch_name,
            'totalDevices': len(valid_devices),
            'batchCount': len(batches),
            'batches': batch_references,
            'tpsPerLambda': tps_per_lambda,
            'maxConcurrency': max_concurrency,
            's3Url': s3_url,
            'validationSummary': {
                'total_devices': validation_result['total_devices'],
                'valid_devices': len(valid_devices),
                'invalid_devices': len(validation_result['invalid_devices'])
            }
        }

    except Exception as e:
        # logger.exception logs at ERROR level and includes the traceback.
        logger.exception(f"Input processing failed: {str(e)}")
        emit_metric('InputProcessingFailure', 1, 'Count')
        return {
            'statusCode': 500,
            'error': str(e)
        }