Part of the [[Distillery Master]]

The `create_routine` function is an asynchronous function that receives a tuple popped from the GenerationQueue, processes the request according to its type (inference or training), and pushes the generated images or training files to the SendQueue for further processing.

### Create Routine Main Function

```python
import sys

import better_exceptions

# AWSManager, INSTANCE_IDENTIFIER, MAX_GENERATIONQUEUE_POP_COUNT, process_training,
# process_inference and generationqueue_pop_counter are defined elsewhere in the Distillery Master module.

async def create_routine(queue_item):
    global generationqueue_pop_counter
    request_id = "N/A"
    aws_manager = None  # stays None if AWSManager.get_instance() itself fails
    try:
        aws_manager = await AWSManager.get_instance()
        # Unpack the tuple popped from the GenerationQueue
        request_id = queue_item[0]
        username = queue_item[1]
        generation_input_timestamp = queue_item[2]
        payload = queue_item[3]
        generation_command_args = queue_item[4]
        message_data = queue_item[5]
        generation_other_data = queue_item[6]
        generation_output_timespentingenerationqueue = queue_item[7]

        # Training requests set is_training=True; anything else (including a missing key) is inference
        if generation_other_data.get('is_training') == True:
            generation_output, processing_time = await process_training(request_id, username, payload, generation_command_args, aws_manager)
        else:
            generation_output, processing_time = await process_inference(request_id, username, payload, generation_command_args, aws_manager)

        # processing_time is the completion timestamp; subtract the queue wait and the arrival timestamp
        generation_output_timetogenerateimagefile = processing_time - generation_output_timespentingenerationqueue - generation_input_timestamp

        await aws_manager.push_send_queue(request_id, username, generation_input_timestamp, payload, generation_command_args, message_data, generation_other_data, generation_output, generation_output_timespentingenerationqueue, generation_output_timetogenerateimagefile)
        aws_manager.print_log(request_id, INSTANCE_IDENTIFIER, f"Create_Routine: images pushed to SendQueue. Images: {generation_output}", level='INFO')
    except Exception:
        formatted_exception = better_exceptions.format_exception(*sys.exc_info())
        formatted_traceback = ''.join(formatted_exception)
        if aws_manager is not None:
            aws_manager.print_log(request_id, INSTANCE_IDENTIFIER, formatted_traceback, level='ERROR')
        else:
            print(formatted_traceback, file=sys.stderr)
    finally:
        # Decrement the GenerationQueue pop counter for this finished task
        generationqueue_pop_counter -= 1
        if aws_manager is not None:
            if generationqueue_pop_counter == MAX_GENERATIONQUEUE_POP_COUNT - 1 and MAX_GENERATIONQUEUE_POP_COUNT > 1:
                aws_manager.print_log("N/A", INSTANCE_IDENTIFIER, f"Create_Routine: GenerationQueue reduced to below {MAX_GENERATIONQUEUE_POP_COUNT} (MAX_GENERATIONQUEUE_POP_COUNT).", level='WARNING')
            if generationqueue_pop_counter == 0:
                aws_manager.print_log("N/A", INSTANCE_IDENTIFIER, "Create_Routine: GenerationQueue reduced to zero.", level='INFO')
```

The main `create_routine` function does the following:

1. It retrieves an instance of the [[AWSManager]] class via `AWSManager.get_instance()`.
2. It unpacks the tuple popped from the [[GenerationQueue]] into its eight fields: request ID, username, input timestamp, payload, command arguments, message data, other data, and time spent in the queue.
3. It checks the `'is_training'` flag in `generation_other_data`:
   - If `'is_training'` is `True`, it calls the `process_training` function to handle the training request.
   - Otherwise, including when the key is absent, it calls the `process_inference` function to handle the inference request.
4. It calculates the time taken to generate the image file. This is the completion timestamp returned by the processing function, minus the time spent in the GenerationQueue, minus the input timestamp.
5. It pushes the generated output and related data to the [[SendQueue]] using `aws_manager.push_send_queue`.
6. It logs a message indicating that the images have been pushed to the [[SendQueue]].
7. If an exception occurs, it formats the traceback with `better_exceptions.format_exception` and logs it at `ERROR` level.
8. In the `finally` block, it decrements `generationqueue_pop_counter`. It logs a `WARNING` when the counter drops to one below `MAX_GENERATIONQUEUE_POP_COUNT`, but only when that limit is greater than 1. It logs an `INFO` message when the counter reaches zero.
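Below is a minimal, self-contained sketch of the dispatch and timing arithmetic in `create_routine`. The stubs `stub_training` and `stub_inference` and the helper `dispatch` are hypothetical stand-ins for `process_training` and `process_inference`. They return fixed outputs and a fixed completion timestamp so that the time calculation can be checked by hand. The AWS calls and the pop counter are left out.

```python
import asyncio
import json

# Hypothetical stand-ins for process_training / process_inference:
# each returns (generation_output, completion_timestamp) like the originals.
async def stub_training(payload):
    return json.dumps({"lora_name": "demo"}), 1000.0

async def stub_inference(payload):
    return json.dumps(["img_0.png", "img_1.png"]), 1000.0

async def dispatch(queue_item):
    """Mirror create_routine's dispatch and timing logic on an 8-field queue tuple."""
    (request_id, username, input_timestamp, payload,
     command_args, message_data, other_data, time_in_queue) = queue_item
    # A missing 'is_training' key falls through to inference
    if other_data.get("is_training") == True:
        output, finished_at = await stub_training(payload)
    else:
        output, finished_at = await stub_inference(payload)
    # Generation time = completion timestamp - queue wait - arrival timestamp
    time_to_generate = finished_at - time_in_queue - input_timestamp
    return output, time_to_generate

item = ("req-1", "alice", 940.0, {}, {}, {}, {}, 15.0)
output, time_to_generate = asyncio.run(dispatch(item))
print(output, time_to_generate)
```

The request arrives at timestamp 940.0, spends 15.0 seconds in the queue, and completes at 1000.0. The reported generation time is therefore 45.0 seconds.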
### Process Inference for Master Create Routine

```python
import asyncio
import copy
import json
import time

# PayloadBuilder, call_runpod, flatten_list and INSTANCE_IDENTIFIER are defined elsewhere in the Distillery Master codebase.

async def process_inference(request_id, username, payload, generation_command_args, aws_manager):
    total_batches = int(generation_command_args['TOTAL_BATCHES'])
    images_per_batch = int(generation_command_args['IMG_PER_BATCH'])
    starting_seed = int(payload['template_inputs']['NOISE_SEED'])

    async def fetch_image(i):
        # Each batch gets its own copy so concurrent batches never overwrite each other's seed
        local_payload = copy.deepcopy(payload)
        # Each batch starts IMG_PER_BATCH seeds after the previous one
        new_seed = starting_seed + i * images_per_batch
        local_payload['comfy_api'] = PayloadBuilder.update_paths(local_payload['comfy_api'], local_payload['noise_seed_template_paths'], str(new_seed))
        local_payload['template_inputs']['NOISE_SEED'] = str(new_seed)
        aws_manager.print_log(request_id, INSTANCE_IDENTIFIER, f"batch {i+1} of {total_batches} - Sending to Runpod - payload['template_inputs'] = {local_payload['template_inputs']}", level='INFO')
        image_files = await call_runpod(request_id, local_payload, generation_command_args)
        return image_files

    # Run all batches concurrently; gather returns results in task order
    tasks = [fetch_image(i) for i in range(total_batches)]
    images = await asyncio.gather(*tasks)
    image_urls = flatten_list(images)
    generation_output = json.dumps(image_urls)
    return generation_output, time.time()
```

The `process_inference` function handles inference requests:

1. It reads the total number of batches from `generation_command_args` and the starting seed from `payload['template_inputs']['NOISE_SEED']`.
2. It defines an asynchronous `fetch_image` function that:
   - Creates a [[deep copy]] of the payload, so concurrent batches do not modify each other's data.
   - Calculates a new seed by offsetting the starting seed by the batch index multiplied by `IMG_PER_BATCH`.
   - Updates `comfy_api` (via `PayloadBuilder.update_paths`) and `template_inputs` with the new seed.
   - Logs a message indicating which batch is being sent to Runpod.
   - Calls the `call_runpod` function to generate the images.
3. It creates one `fetch_image` task per batch.
4. It runs the tasks concurrently and awaits them with `asyncio.gather`, which returns the results in task order.
5. It flattens the per-batch lists of images using the `flatten_list` function.
6. It serializes the image URLs to JSON.
7. It returns the JSON output together with the completion timestamp (`time.time()`).

### Process Training for Master Create Routine

```python
import json
import time

# call_runpod is defined elsewhere in the Distillery Master codebase.

async def process_training(request_id, username, payload, generation_command_args, aws_manager):
    # Run the training job on Runpod; it returns a dict describing the trained LORA file
    lora_file_dict = await call_runpod(request_id, payload, generation_command_args)
    if not lora_file_dict:
        raise Exception("LORA file dict is empty.")

    preset_dict = payload['preset_dict']
    lora_name = lora_file_dict['lora_name']
    default_strength_model = preset_dict['LORA_STRENGTH_MODEL']
    default_strength_clip = preset_dict['LORA_STRENGTH_CLIP']
    # The preset default applies unless the parsed command contains 'ispublic', which forces a public model
    lora_model_is_private = preset_dict['LORA_DEFAULT_IS_PRIVATE'] if 'ispublic' not in payload['parsed_output'] else False
    lora_group = preset_dict['LORA_GROUP']
    lora_model_type = preset_dict['MODEL_TYPE']
    lora_model_file_name = lora_file_dict['lora_model_file_name']
    lora_model_owner = username
    lora_file_dict['description'] = ""

    await aws_manager.add_lora_model(lora_name, lora_model_file_name, lora_model_type, default_strength_model, default_strength_clip, lora_model_is_private, lora_model_owner, lora_model_other_data=lora_file_dict)
    await aws_manager.add_lora_to_group(lora_name, lora_group)
    if lora_model_is_private:
        await aws_manager.add_lora_credentials(lora_name, username)

    return json.dumps(lora_file_dict), time.time()
```

The `process_training` function handles training requests:

1. It calls the `call_runpod` function to generate the [[LORA]] file and retrieve the [[LORA]] file dictionary.
2. It raises an exception if the LORA file dictionary is empty.
3. It extracts the relevant information from the LORA file dictionary and the payload: the LORA name, default model and CLIP strengths, privacy setting, group, model type, file name, and owner. The privacy setting comes from the preset's `LORA_DEFAULT_IS_PRIVATE` unless the parsed command contains `ispublic`, in which case the model is public. It also sets the dictionary's `description` field to an empty string.
4. It adds the LORA model to the database using `aws_manager.add_lora_model`.
5. It adds the LORA model to the specified group using `aws_manager.add_lora_to_group`.
6. If the LORA model is private, it registers credentials for the owner (`username`) using `aws_manager.add_lora_credentials`.
7. It returns the LORA file dictionary as JSON and the completion timestamp.
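The registration sequence at the end of `process_training` can be exercised without AWS by swapping in a recording stub. `RecordingManager`, `resolve_privacy` and `register_lora` are hypothetical names used only in this sketch. The preset values are made-up examples, and `parsed_output` is assumed to be a dict. Only the call order and the privacy rule mirror the code above.

```python
import asyncio
import json

class RecordingManager:
    """Hypothetical stand-in for AWSManager that records registration calls."""
    def __init__(self):
        self.calls = []

    async def add_lora_model(self, lora_name, *args, **kwargs):
        self.calls.append(("add_lora_model", lora_name))

    async def add_lora_to_group(self, lora_name, lora_group):
        self.calls.append(("add_lora_to_group", lora_name, lora_group))

    async def add_lora_credentials(self, lora_name, username):
        self.calls.append(("add_lora_credentials", lora_name, username))

def resolve_privacy(preset_dict, parsed_output):
    # 'ispublic' in the parsed command overrides the preset default
    return preset_dict["LORA_DEFAULT_IS_PRIVATE"] if "ispublic" not in parsed_output else False

async def register_lora(manager, lora_file_dict, payload, username):
    preset = payload["preset_dict"]
    name = lora_file_dict["lora_name"]
    is_private = resolve_privacy(preset, payload["parsed_output"])
    await manager.add_lora_model(name, lora_file_dict["lora_model_file_name"], preset["MODEL_TYPE"],
                                 preset["LORA_STRENGTH_MODEL"], preset["LORA_STRENGTH_CLIP"],
                                 is_private, username, lora_model_other_data=lora_file_dict)
    await manager.add_lora_to_group(name, preset["LORA_GROUP"])
    # Only private models get a credentials entry for the owner
    if is_private:
        await manager.add_lora_credentials(name, username)
    return json.dumps(lora_file_dict)

preset = {
    "LORA_DEFAULT_IS_PRIVATE": True,
    "LORA_STRENGTH_MODEL": 1.0,
    "LORA_STRENGTH_CLIP": 1.0,
    "LORA_GROUP": "styles",
    "MODEL_TYPE": "example-model-type",
}
manager = RecordingManager()
lora_file = {"lora_name": "mystyle", "lora_model_file_name": "mystyle.safetensors"}
asyncio.run(register_lora(manager, lora_file, {"preset_dict": preset, "parsed_output": {}}, "alice"))
print(manager.calls)
```

With a private-by-default preset and no `ispublic` in the command, all three registration calls are made. Adding `ispublic` to `parsed_output` skips the credentials step.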
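The seed-offset scheme in `process_inference` can also be checked in isolation. In this sketch, `fake_call_runpod`, `flatten` and `run_batches` are hypothetical stand-ins for `call_runpod`, `flatten_list` and the batching logic. The fake backend is assumed to use consecutive seeds for the images within a batch. That is an assumption about the Runpod worker, not something `process_inference` itself guarantees.

```python
import asyncio
import copy
import json

# Hypothetical stand-in for call_runpod: "renders" one file name per seed in the batch,
# assuming the backend uses consecutive seeds within a batch.
async def fake_call_runpod(payload, images_per_batch):
    seed = int(payload["template_inputs"]["NOISE_SEED"])
    await asyncio.sleep(0)  # yield control, as a real network call would
    return [f"seed_{seed + k}.png" for k in range(images_per_batch)]

def flatten(batches):
    # Stand-in for flatten_list: concatenate per-batch lists in order
    return [item for batch in batches for item in batch]

async def run_batches(payload, total_batches, images_per_batch):
    starting_seed = int(payload["template_inputs"]["NOISE_SEED"])

    async def fetch(i):
        local = copy.deepcopy(payload)  # batches must not share the mutable payload
        local["template_inputs"]["NOISE_SEED"] = str(starting_seed + i * images_per_batch)
        return await fake_call_runpod(local, images_per_batch)

    # gather preserves task order, so results line up with batch indices
    batches = await asyncio.gather(*(fetch(i) for i in range(total_batches)))
    return json.dumps(flatten(batches))

payload = {"template_inputs": {"NOISE_SEED": "100"}}
print(asyncio.run(run_batches(payload, total_batches=3, images_per_batch=2)))
```

Suppose the starting seed is 100, with three batches of two images each. The batches then start at seeds 100, 102 and 104. The six images cover seeds 100 through 105 with no overlap, and the caller's `payload` is left unchanged.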