1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
| import tritonclient.grpc as grpcclient import tritonclient.utils.shared_memory as shm import asyncio from asyncio import Queue from typing import Awaitable import logging import os from pathlib import Path
class ShmRegion(object):
    """A Triton system shared-memory region that is handed out through a queue.

    The segment is created and registered with the Triton server when the
    object is constructed. Used as a context manager, leaving the ``with``
    block puts the region back on its queue so it can be taken again.
    """

    def __init__(self, triton_client: "grpcclient.InferenceServerClient",
                 shm_queue: "Queue | None", max_data_size: int,
                 shm_name: str, shm_key: str):
        """Create the shared-memory segment and register it with Triton.

        Args:
            triton_client: client used to register/unregister the region.
            shm_queue: queue the region is returned to. May be None and bound
                later through ``addToQueue``.
            max_data_size: byte size passed to the segment creation and to the
                Triton registration.
            shm_name: name under which the region is registered with Triton.
            shm_key: key of the system shared-memory segment. A leftover
                ``/dev/shm/<key>`` file is removed before the segment is created.
        """
        self.name = shm_name
        self.key = shm_key
        self.shm_queue = shm_queue
        self.size = max_data_size
        self.triton_client = triton_client
        self.shm_path = Path(f"/dev/shm/{self.key}")
        if self.shm_path.exists():
            # A segment with this key already exists (e.g. from an earlier run):
            # drop it and the registration under this name before re-creating it.
            self.shm_path.unlink()
            self.triton_client.unregister_system_shared_memory(self.name)
        self.handle = shm.create_shared_memory_region(self.name, self.key, max_data_size)
        self.triton_client.register_system_shared_memory(self.name, self.key, max_data_size)
        logging.info(f"shm region {self.name} registered")

    def addToQueue(self, shm_queue=None):
        """Put this region on its queue, optionally binding the queue first.

        Args:
            shm_queue: queue to bind. If a queue is already bound it must be the
                same object. When omitted, the previously bound queue is used.

        Raises:
            AssertionError: on a queue mismatch, or if no queue was ever bound.
            asyncio.QueueFull: if the bounded queue has no free slot.
        """
        if shm_queue is not None:
            assert self.shm_queue is None or self.shm_queue is shm_queue, \
                f"shm region {self.name} is already bound to a different queue"
            self.shm_queue = shm_queue
        else:
            assert self.shm_queue is not None, \
                f"shm region {self.name} has no queue to return to"
        self.shm_queue.put_nowait(self)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Return the region to the pool. Returning None lets exceptions propagate.
        self.addToQueue()

    def __del__(self):
        """Unregister the region from Triton and destroy the shared-memory segment."""
        handle = getattr(self, "handle", None)
        if handle is None:
            # __init__ failed before the segment was created: there is nothing to
            # release (and name/triton_client may not even be set).
            return
        logging.info(f"shm region {self.name} removed")
        try:
            self.triton_client.unregister_system_shared_memory(self.name)
        finally:
            # Destroy the local segment even if the server call fails, so the
            # segment is not leaked.
            shm.destroy_shared_memory_region(handle)
class ShmTritonClient(object):
    """Pool of ``ShmRegion`` objects handed out through an asyncio queue.

    All regions are created (and registered with Triton) up front. Each one is
    placed on the queue the first time ``getRegion`` reaches it.
    """

    def __init__(self, triton_client, max_queue_size, max_data_size, shm_name_prefix, shm_key_prefix):
        """Create ``max_queue_size`` regions.

        Region ``i`` is named ``<shm_name_prefix>_<i>`` and uses the key
        ``<shm_key_prefix>_<i>``.

        Args:
            triton_client: client passed to every ``ShmRegion``.
            max_queue_size: number of regions in the pool and capacity of the queue.
            max_data_size: size of each region.
            shm_name_prefix: prefix of the Triton registration names.
            shm_key_prefix: prefix of the shared-memory keys.
        """
        self.triton_client = triton_client
        # Built lazily in getRegion so the queue is created inside the running loop.
        self.shm_queue = None
        self.max_queue_size = max_queue_size
        # How many regions have been placed on the queue so far.
        self.registered_regions = 0
        # Regions start without a queue (self.shm_queue is still None here).
        # getRegion binds them one at a time.
        self.regions = [
            ShmRegion(self.triton_client, self.shm_queue, max_data_size,
                      f"{shm_name_prefix}_{i}", f"{shm_key_prefix}_{i}")
            for i in range(max_queue_size)
        ]

    def getRegion(self) -> "Awaitable[ShmRegion]":
        """Return an awaitable that yields a free region.

        On each call, if some region has not been placed on the queue yet, one
        more is added first. Must be called while an event loop is running.

        Raises:
            RuntimeError: if no event loop is running.
        """
        if self.shm_queue is None:
            # Fail fast outside a running loop. Building the queue here binds it
            # to that loop. ``loop=`` is not passed: asyncio.Queue no longer
            # accepts it (removed in Python 3.10).
            asyncio.get_running_loop()
            self.shm_queue = Queue(maxsize=self.max_queue_size)
        if self.registered_regions < len(self.regions):
            self.regions[self.registered_regions].addToQueue(self.shm_queue)
            self.registered_regions += 1
        return self.shm_queue.get()
|