From 0d31f8d9aff2c75c5ec3755288dce67ff45864aa Mon Sep 17 00:00:00 2001 From: Joe Farrington Date: Sat, 14 Dec 2024 22:19:16 +0000 Subject: [PATCH] Fix _STEP and _RESET for TempVMPackingEnv --- or_gym/envs/classic_or/vmpacking.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/or_gym/envs/classic_or/vmpacking.py b/or_gym/envs/classic_or/vmpacking.py index 5a4475f3..dd84a96b 100644 --- a/or_gym/envs/classic_or/vmpacking.py +++ b/or_gym/envs/classic_or/vmpacking.py @@ -172,7 +172,7 @@ def __init__(self, *args, **kwargs): super().__init__() self.state = self.reset() - def step(self, action): + def _STEP(self, action): done = False pm_state = self.state["state"][:-1] demand = self.state["state"][-1, 1:] @@ -198,7 +198,8 @@ def step(self, action): # Remove process from PM if self.durations[process] == self.current_step: pm = self.assignment[process] # Find PM where process was assigned - pm_state[pm, self.load_idx] -= self.demand[process] + pm_state[pm, self.load_idx] -= self.demand[process][1:] # Index to exclude first element of demand array + pm_state[pm, self.load_idx] = np.where(pm_state[pm, self.load_idx]1,1,data_center) # Fix rounding errors self.state["state"] = data_center - self.state["action_mask"] = np.ones(self.n_pms) - self.state["avail_actions"] = np.ones(self.n_pms) + self.state["action_mask"] = np.ones(self.n_pms, dtype=np.uint8) + self.state["avail_actions"] = np.ones(self.n_pms, dtype=np.uint8) if self.mask: action_mask = (pm_state[:, 1:] + self.demand[step, 1:]) <= 1 - self.state["action_mask"] = (action_mask.sum(axis=1)==2).astype(int) + self.state["action_mask"] = (action_mask.sum(axis=1)==2).astype(np.uint8) def _RESET(self): self.current_step = 0 self.assignment = {} self.demand = self.generate_demand() self.durations = generate_durations(self.demand) - self.state = (np.zeros((self.n_pms, 3)), self.demand[0]) + + self.state = { + "action_mask": np.ones(self.n_pms, dtype=np.uint8), + "avail_actions": np.ones(self.n_pms, dtype=np.uint8), + "state": np.vstack([ + np.zeros((self.n_pms, 3)), + self.demand[self.current_step]], + dtype=np.float32) + } return self.state def step(self, action):