Example #1
0
    def update(self, selected):
        """Advance the tour state by one decoding step.

        Args:
            selected: (batch,) tensor of chosen node indices.

        Returns:
            A new state namedtuple with position, length, visited mask and
            step counter updated.
        """
        # Keep a step dimension so downstream indexing stays (batch, 1).
        step_node = selected[:, None]

        new_coord = self.loc[self.ids, step_node]
        # The very first action only picks the start node, so no distance is
        # added for it.
        if self.cur_coord is None:
            new_lengths = self.lengths
        else:
            travelled = (new_coord - self.cur_coord).norm(p=2, dim=-1)
            new_lengths = self.lengths + travelled  # (batch_dim, 1)

        # Only one parallel step is expected, so i == 0 identifies the first call.
        start_node = step_node if self.i.item() == 0 else self.first_a

        if self.visited_.dtype == torch.bool:
            # Boolean mask: write a single value, hence the extra dimension.
            new_visited = self.visited_.scatter(-1, step_node[:, :, None], 1)
        else:
            # Packed int64 bit-mask representation.
            new_visited = mask_long_scatter(self.visited_, step_node)

        return self._replace(first_a=start_node,
                             prev_a=step_node,
                             visited_=new_visited,
                             lengths=new_lengths,
                             cur_coord=new_coord,
                             i=self.i + 1)
Example #2
0
    def update(self, selected):
        """Advance the prize-collecting state one step with the chosen nodes.

        Args:
            selected: (batch,) tensor of selected node indices (0 is the depot).

        Returns:
            New state with position, length, collected prize and visited mask
            updated.
        """
        # The state must describe a single step for this update to be valid.
        assert self.i.size(0) == 1, "Can only update if state represents single step"

        step_idx = selected[:, None]  # (batch, 1): add the step dimension

        # Distance travelled from the previous position to the new one.
        new_coord = self.coords[self.ids, step_idx]
        new_lengths = self.lengths + (new_coord - self.cur_coord).norm(p=2, dim=-1)  # (batch_dim, 1)

        # Accumulate the prize collected at the newly visited node.
        new_total_prize = self.cur_total_prize + self.prize[self.ids, step_idx]

        if self.visited_.dtype == torch.uint8:
            # No index shift: the mask keeps a depot column so the raw index
            # can be scattered directly (extra dim because one value is written).
            new_visited = self.visited_.scatter(-1, step_idx[:, :, None], 1)
        else:
            # check_unset=False permits marking the depot visited a second time.
            new_visited = mask_long_scatter(self.visited_,
                                            step_idx,
                                            check_unset=False)

        return self._replace(prev_a=step_idx,
                             visited_=new_visited,
                             lengths=new_lengths,
                             cur_coord=new_coord,
                             cur_total_prize=new_total_prize,
                             i=self.i + 1)
Example #3
0
 def update(self, selected, car_id, step):
     """Advance the multi-car state by one step for the given car.

     Args:
         selected: (batch_size,) tensor of chosen node indices for this car.
         car_id: index of the car being moved.
         step: global step counter stored into ``i``.

     Returns:
         A new state with position, lengths, visited mask and choice mask
         updated for ``car_id``.
     """
     batch_size = selected.shape[0]
     # BUG FIX: clone before the in-place per-car write below; assigning into
     # self.prev_a directly would silently mutate the previous state too.
     prev_a = self.prev_a.clone()
     # BUG FIX: .long() keeps the tensor on its current device;
     # .type(torch.LongTensor) forced a copy to CPU and broke CUDA runs.
     prev_a[car_id, ...] = selected[:, None].view(batch_size, -1).long()
     first_a = self.first_a
     if self.cur_coord is None:
         cur_coord = torch.zeros(self.n_cars, batch_size, 2, device=prev_a.device)
     else:
         # Clone for the same aliasing reason as prev_a above.
         cur_coord = self.cur_coord.clone()
     visited_ = self.visited_
     lengths = self.lengths
     prev_a_ = prev_a[car_id, ...]
     cur_coord_ = self.loc[self.ids, prev_a_].view(batch_size, -1)
     if self.cur_coord is not None:  # Don't add length for first action (selection of start node)
         lengths = lengths + (cur_coord_ - self.cur_coord[car_id, ...]).norm(p=2, dim=-1).view(-1, 1)  # (batch_dim, 1)
     if self.visited_.dtype == torch.bool:
         # Boolean mask: mark where this car now stands (extra dim for scatter).
         visited_ = visited_.scatter(-1, prev_a_[:, :, None], 1)
     else:
         # Packed int64 bit-mask representation.
         visited_ = mask_long_scatter(self.visited_, prev_a_)
     if self.allow_repeated_choices:
         new_mask = self._update_mask(selected)
     else:
         new_mask = self.mask
     cur_coord[car_id, ...] = cur_coord_
     return self._replace(first_a=first_a, prev_a=prev_a,
                          cur_coord=cur_coord, i=torch.tensor(step),
                          visited_=visited_, lengths=lengths, mask=new_mask)
Example #4
0
    def update(self, selected):
        """Apply one CVRPTW routing step: move to `selected`, updating length,
        capacity and the service / delay / early time accumulators.

        Args:
            selected: (batch,) tensor of chosen node indices (0 is the depot).

        Returns:
            New state advanced by one step.
        """
        assert self.i.size(0) == 1, "Can only update if state represents single step"

        step_idx = selected[:, None]  # (batch, 1): add the step dimension
        n_customers = self.demand.size(-1)  # demand excludes the depot

        # Distance travelled from the previous position.
        next_coord = self.coords[self.ids, step_idx]
        dist = (next_coord - self.cur_coord).norm(p=2, dim=-1)  # (batch_dim, 1)
        next_time = self.cur_time + dist
        new_lengths = self.lengths + dist

        # Time-window bookkeeping: accumulate service, lateness and earliness.
        new_service = self.total_service_times + self.service_times[self.ids, step_idx]
        slack = self.time_window_finish[self.ids, step_idx] - next_time
        ahead = next_time - self.time_window_start[self.ids, step_idx]
        new_delay = self.total_delay_times + slack * (slack > 0).float()
        new_early = self.total_early_times + ahead * (ahead > 0).float()

        # Demand of the selected node; the clamp makes depot visits read node
        # 0's demand, which is wrong there, but capacity is zeroed below anyway.
        step_demand = self.demand[self.ids,
                                  torch.clamp(step_idx - 1, 0, n_customers - 1)]
        # Grow the load, or reset it to 0 whenever the depot is visited.
        new_capacity = (self.used_capacity + step_demand) * (step_idx != 0).float()

        if self.visited_.dtype == torch.uint8:
            # The mask keeps a depot column, so scatter the raw index directly
            # (extra dim because a single value is written per row).
            new_visited = self.visited_.scatter(-1, step_idx[:, :, None], 1)
        else:
            # Packed bit-mask form; index -1 (the depot) is simply not set.
            new_visited = mask_long_scatter(self.visited_, step_idx - 1)

        return self._replace(prev_a=step_idx,
                             used_capacity=new_capacity,
                             visited_=new_visited,
                             lengths=new_lengths,
                             cur_coord=next_coord,
                             i=self.i + 1,
                             total_service_times=new_service,
                             total_delay_times=new_delay,
                             total_early_times=new_early)
    def update(self, selected):
        """One pickup-and-delivery step: visit `selected` and unlock its pair node.

        Args:
            selected: (batch,) tensor of chosen node indices (0 is the depot).

        Returns:
            New state with position, length, visited and to-delivery masks
            updated.
        """
        assert self.i.size(
            0) == 1, "Can only update if state represents single step"

        # Update the state
        n_loc = self.to_delivery.size(-1) - 1  # number of customers
        # The delivery node paired with the selected pickup node.
        new_to_delivery = (selected + n_loc // 2) % (
            n_loc + 1)  # the pair node of selected node
        new_to_delivery = new_to_delivery[:, None]
        selected = selected[:, None]  # Add dimension for step
        prev_a = selected

        # Add the length
        cur_coord = self.coords[self.ids, selected]
        lengths = self.lengths + (cur_coord - self.cur_coord).norm(
            p=2, dim=-1)  # (batch_dim, 1)

        if self.visited_.dtype == torch.uint8:
            # Note: here we do not subtract one as we have to scatter so the first column allows scattering depot
            # Add one dimension since we write a single value
            visited_ = self.visited_.scatter(-1, prev_a[:, :, None], 1)
            to_delivery = self.to_delivery.scatter(-1, new_to_delivery[:, :,
                                                                       None],
                                                   1)
        else:
            # This works, will not set anything if prev_a -1 == -1 (depot)
            visited_ = mask_long_scatter(self.visited_, prev_a - 1)
            # BUG FIX: `to_delivery` was left undefined on this branch, raising
            # UnboundLocalError at the _replace call below.  Mirror the uint8
            # branch using the packed-mask scatter helper.
            to_delivery = mask_long_scatter(self.to_delivery, new_to_delivery)

        return self._replace(
            prev_a=prev_a,
            visited_=visited_,
            lengths=lengths,
            cur_coord=cur_coord,
            i=self.i + 1,
            to_delivery=to_delivery,
        )
Example #6
0
    def initialize(input, allow_repeated_choices, visited_dtype=torch.bool):
        """Build the initial multi-car routing state from a problem instance.

        Args:
            input: dict with 'depot' (batch, 2), 'loc' (batch, n_loc, 2) and
                'n_cars' entries.
            allow_repeated_choices: whether cars may re-pick chosen nodes.
            visited_dtype: torch.bool for a plain mask; any other value selects
                the packed int64 bit-mask representation.

        Returns:
            A fresh StateMTSP with every car placed at the depot.
        """
        depot = input['depot'].clone()
        loc = input['loc'].clone()
        n_cars = input['n_cars'][0]
        batch_size, n_loc, _ = loc.size()
        device = loc.device

        prev_a = torch.zeros(n_cars, batch_size, 1, dtype=torch.long, device=device)
        index_to_choices = create_index_to_choices(n_cars, n_loc, device)
        index_size = index_to_choices.shape[0]
        mask = torch.zeros([batch_size, 1, index_size], dtype=torch.bool, device=device)
        can_repeat = torch.ones([batch_size, n_cars])

        # Visited as a mask is easier to understand; packed longs are more
        # memory efficient (64 node flags per int64, hence the ceil division).
        if visited_dtype == torch.bool:
            visited_ = torch.zeros(batch_size, 1, n_loc + 1, dtype=torch.bool, device=device)
        else:
            visited_ = torch.zeros(batch_size, 1, (n_loc + 63) // 64, dtype=torch.int64, device=device)

        # All cars start at the depot (node 0), so mark it visited right away.
        start_pos = prev_a[0, ...]
        if visited_.dtype == torch.bool:
            # Extra dimension because a single value is written per row.
            visited_ = visited_.scatter(-1, start_pos[:, :, None], 1)
        else:
            visited_ = mask_long_scatter(visited_, start_pos)

        return StateMTSP(
            loc=torch.cat((depot[:, None, :], loc), -2),
            dist=(loc[:, :, None, :] - loc[:, None, :, :]).norm(p=2, dim=-1),
            ids=torch.arange(batch_size, dtype=torch.int64, device=device)[:, None],  # Add steps dimension
            first_a=torch.zeros_like(prev_a, device=device),
            prev_a=prev_a,
            # Keep the depot column so scatter works for depot actions too.
            visited_=visited_,
            lengths=torch.zeros(batch_size, 1, device=device),
            cur_coord=input['depot'].clone()[None, :, :].expand([n_cars, batch_size, 2]),
            i=torch.zeros(1, dtype=torch.int64, device=device),  # Vector with length num_steps
            n_cars=n_cars,
            index_to_choices=index_to_choices,
            allow_repeated_choices=allow_repeated_choices,
            can_repeat=can_repeat,
            mask=mask
        )
Example #7
0
    def update(self, selected):
        """Record the chosen node and mark it visited; no lengths are tracked here.

        Args:
            selected: (batch_size,) tensor of chosen node indices.

        Returns:
            New state with prev_a, first_a, visited_ and the step counter
            advanced.
        """
        step_node = selected[:, None]  # (batch_size, 1): add the step dimension

        # A single parallel step is assumed, so i == 0 marks the very first action.
        start_node = step_node if self.i.item() == 0 else self.first_a

        if self.visited_.dtype == torch.uint8:
            # Extra dimension because a single value is written per row.
            new_visited = self.visited_.scatter(-1, step_node[:, :, None], 1)
        else:
            # Packed int64 bit-mask representation.
            new_visited = mask_long_scatter(self.visited_, step_node)

        return self._replace(first_a=start_node, prev_a=step_node,
                             visited_=new_visited, i=self.i + 1)
Example #8
0
    def update(self, selected):
        """Advance one step, storing the previous choice into first_a.

        Unlike the usual variant, first_a holds the node picked one step
        earlier rather than the very first node of the tour.
        """
        # Directly reuse the choice from the previous step.
        one_step_back = self.prev_a

        step_node = selected[:, None]
        new_coord = self.loc[self.ids, step_node]
        # No distance is added for the first action (start-node selection).
        if self.cur_coord is None:
            new_lengths = self.lengths
        else:
            new_lengths = self.lengths + (new_coord - self.cur_coord).norm(p=2, dim=-1)

        if self.visited_.dtype == torch.uint8:
            # Extra dimension because a single value is written per row.
            new_visited = self.visited_.scatter(-1, step_node[:, :, None], 1)
        else:
            # Packed int64 bit-mask representation.
            new_visited = mask_long_scatter(self.visited_, step_node)

        return self._replace(first_a=one_step_back,
                             prev_a=step_node,
                             visited_=new_visited,
                             lengths=new_lengths,
                             cur_coord=new_coord,
                             i=self.i + 1)
Example #9
0
    def update(self, selected, vehicle_count):
        """Advance a multi-vehicle routing state by one step.

        The flat action encodes vehicle and node together:
        node = selected // vehicle_count, vehicle = selected % vehicle_count.
        Only the chosen vehicle's row of prev_a / used_capacity / lengths /
        cur_coord is rewritten (via scatter along the vehicle axis).
        """
        assert self.i.size(
            0) == 1, "Can only update if state represents single step"

        # Update the state: split the flat action into (vehicle, node).
        vehicle_index = (selected % vehicle_count)[:, None]
        selected_node = selected // vehicle_count
        selected = selected_node[:, None]  # Add dimension for step
        prev_a = selected
        n_loc = self.demand.size(-1)  # Excludes depot

        # Add the length travelled by the chosen vehicle only.
        cur_coord = self.coords[self.ids, selected]
        lengths = self.lengths[self.ids, 0, vehicle_index] + \
            (cur_coord - self.cur_coord[self.ids, 0, vehicle_index]).norm(p=2, dim=-1)  # (batch_dim, 1)

        # Note: selected_demand is the demand of the first node (by the clamp),
        # so it is incorrect for depot visits — but used_capacity is zeroed
        # for those right below.
        selected_demand = self.demand[self.ids,
                                      torch.clamp(prev_a - 1, 0, n_loc - 1)]

        # Increase capacity if depot is not visited, otherwise set to 0
        used_capacity = (self.used_capacity[self.ids, 0, vehicle_index] +
                         selected_demand) * (prev_a != 0).float()

        if self.visited_.dtype == torch.uint8:
            # The visited mask interleaves vehicle_count columns per node, so
            # every vehicle slot of the chosen node is marked at once.
            # Note: here we do not subtract one as we have to scatter so the first column allows scattering depot
            # Add one dimension since we write a single value
            visited_locations = vehicle_count * prev_a[:, :, None]
            for i in range(1, vehicle_count):
                visited_locations = torch.cat(
                    (visited_locations,
                     vehicle_count * prev_a[:, :, None] + i), -1)

            visited_ = self.visited_.scatter(-1, visited_locations, 1)
        else:
            # Packed bit-mask variant: set every vehicle slot of the node;
            # clamp(min=-1) turns depot entries into index -1, which
            # mask_long_scatter leaves unset.
            # This works, will not set anything if prev_a -1 == -1 (depot)
            visited_ = self.visited_
            for i in range(vehicle_count):
                visited_ = mask_long_scatter(
                    visited_, (vehicle_count * (prev_a - 1) + i).clamp(min=-1))

        # Write the chosen vehicle's updated entries back into the full tensors.
        prev_a_tmp = self.prev_a.scatter(-1, vehicle_index[:, :, None],
                                         prev_a[:, None, :])
        used_capacity_tmp = self.used_capacity.scatter(
            -1, vehicle_index[:, :, None], used_capacity[:, None, :])
        lengths_tmp = self.lengths.scatter(-1, vehicle_index[:, :, None],
                                           lengths[:, None, :])
        cur_coord_tmp = self.cur_coord.scatter(-2, vehicle_index[:, :, None,
                                                                 None],
                                               cur_coord[:, :, None, :])

        return self._replace(prev_a=prev_a_tmp,
                             used_capacity=used_capacity_tmp,
                             visited_=visited_,
                             lengths=lengths_tmp,
                             cur_coord=cur_coord_tmp,
                             i=self.i + 1)
Example #10
0
    def update(self, selected, vehicle_count):
        """Advance a multi-vehicle, time-windowed routing state by one step.

        The flat action encodes vehicle and node together:
        node = selected // vehicle_count, vehicle = selected % vehicle_count.
        Arrival/served times, lengths, capacity and the service / early /
        delay accumulators are updated only for the chosen vehicle.
        """
        assert self.i.size(0) == 1, "Can only update if state represents single step"
        # Update the state: split the flat action into (vehicle, node).
        vehicle_index = (selected % vehicle_count)[:, None]
        selected_node = selected // vehicle_count
        selected = selected_node[:, None]  # Add dimension for step
        prev_a = selected
        n_loc = self.demand.size(-1)  # Excludes depot

        # Add the length
        cur_coord = self.coords[self.ids, selected]
        # Arrival time: previous time plus scaled travel time when the vehicle
        # was already en route; when it departs from the depot (prev == 0) the
        # formula instead uses the selected node's window start as the arrival.
        arrival_time = (
                (self.cur_time.gather(-1, vehicle_index) +
                 self.TIME_SCALE*(cur_coord - self.cur_coord[self.ids, 0, vehicle_index]).norm(p=2, dim=-1)) *\
                (self.prev_a[self.ids, 0, vehicle_index] != 0) +
                (self.time_window_start[self.ids, selected])*(self.prev_a[self.ids, 0, vehicle_index] == 0)
        )

        # Service cannot start before the window opens: wait if early.
        cur_time = torch.max(
            arrival_time,
            self.time_window_start[self.ids, selected]
        )

        lengths = self.lengths[self.ids, 0, vehicle_index] +\
            (cur_coord - self.cur_coord[self.ids, 0, vehicle_index]).norm(p=2, dim=-1)  # (batch_dim, 1)
        total_service_times = self.total_service_times[self.ids, 0, vehicle_index] +\
            self.service_times[self.ids, selected]
        # Positive delay = arrival after the window closed; positive early =
        # arrival before it opened.  Only the positive parts are accumulated.
        delay = arrival_time - self.time_window_finish[self.ids, selected]
        early = self.time_window_start[self.ids, selected] - arrival_time
        total_delay_times = self.total_delay_times[self.ids, 0, vehicle_index] + delay*(delay > 0).float()
        total_early_times = self.total_early_times[self.ids, 0, vehicle_index] + early*(early > 0).float()

        # Note: selected_demand is the demand of the first node (by the clamp),
        # so it is wrong for depot visits — but capacity is zeroed for those below.
        selected_demand = self.demand[self.ids, torch.clamp(prev_a - 1, 0, n_loc - 1)]
        # Increase capacity if depot is not visited, otherwise set to 0
        used_capacity = (self.used_capacity[self.ids, 0, vehicle_index] + selected_demand) * (prev_a != 0).float()
        if self.visited_.dtype == torch.uint8:
            # The visited mask interleaves vehicle_count columns per node, so
            # every vehicle slot of the chosen node is marked at once.
            # Note: here we do not subtract one as we have to scatter so the first column allows scattering depot
            # Add one dimension since we write a single value
            visited_locations = vehicle_count*prev_a[:, :, None]
            for i in range(1, vehicle_count):
                visited_locations = torch.cat((visited_locations, vehicle_count * prev_a[:, :, None] + i), -1)

            visited_ = self.visited_.scatter(-1, visited_locations, 1)
        else:
            # Packed bit-mask variant; clamp(min=-1) turns depot entries into
            # index -1, which mask_long_scatter leaves unset.
            # This works, will not set anything if prev_a -1 == -1 (depot)
            visited_ = self.visited_
            for i in range(vehicle_count):
                visited_ = mask_long_scatter(visited_, (vehicle_count*(prev_a - 1) + i).clamp(min=-1))

        # Write the chosen vehicle's updated entries back into the full tensors.
        prev_a_tmp = self.prev_a.scatter(-1, vehicle_index[:, :, None], prev_a[:, None, :])
        used_capacity_tmp = self.used_capacity.scatter(-1, vehicle_index[:, :, None], used_capacity[:, None, :])
        lengths_tmp = self.lengths.scatter(-1, vehicle_index[:, :, None], lengths[:, None, :])
        cur_coord_tmp = self.cur_coord.scatter(-2, vehicle_index[:, :, None, None], cur_coord[:, :, None, :])
        total_service_times_tmp = self.total_service_times.scatter(-1, vehicle_index[:, :, None], total_service_times[:, None, :])
        total_early_times_tmp = self.total_early_times.scatter(-1, vehicle_index[:, :, None], total_early_times[:, None, :])
        total_delay_times_tmp = self.total_delay_times.scatter(-1, vehicle_index[:, :, None], total_delay_times[:, None, :])
        cur_time_tmp = self.cur_time.scatter(-1, vehicle_index, cur_time)

        return self._replace(
            prev_a=prev_a_tmp, used_capacity=used_capacity_tmp, visited_=visited_,
            lengths=lengths_tmp, cur_coord=cur_coord_tmp, i=self.i + 1, total_service_times=total_service_times_tmp,
            total_delay_times=total_delay_times_tmp, total_early_times=total_early_times_tmp, cur_time=cur_time_tmp
        )
Example #11
0
    def update(self, selected):
        """One step of the multi-depot routing state: move to `selected`,
        tracking route restarts at depots, used capacity and tour lengths.

        Args:
            selected: (batch,) tensor of chosen node indices; indices below
                num_depots are depots.

        Returns:
            New state with position, capacity, depot bookkeeping and visited
            mask advanced by one step.
        """
        assert self.i.size(
            0) == 1, "Can only update if state represents single step"

        # Update the state
        selected = selected[:, None]  # Add dimension for step
        prev_a = selected
        n_loc = self.demand.size(-1)  # Excludes depot

        # Add the length
        cur_coord = self.coords[self.ids, selected]
        # BUG FIX: clone before the masked in-place assignment below; writing
        # through a plain alias of self.lengths silently corrupted the tensor
        # held by the previous state object.
        lengths = self.lengths.clone()
        if self.cur_depot is not None:
            cur_depot = self.cur_depot.detach().clone()
        else:
            cur_depot = self.cur_depot
        if self.i > 0:
            # Rows that just (re)started a route pay no travel cost this step.
            lengths[~self.start_new_routes] = self.lengths[
                ~self.start_new_routes] + (
                    cur_coord[~self.start_new_routes] -
                    self.cur_coord[~self.start_new_routes]).norm(
                        p=2, dim=-1)  # (batch_dim, 1)

            # Demand of the selected customer; the clamp makes depot selections
            # read customer 0's demand, but capacity is zeroed for depots below.
            selected_demand = self.demand[
                self.ids,
                torch.clamp(prev_a - self.num_depots, 0, n_loc - 1)]

            # Increase capacity for a customer visit, otherwise reset it to 0.
            used_capacity = (self.used_capacity + selected_demand) * (
                prev_a >= self.num_depots).float()
            # Rows flagged as restarting now depart from the freshly chosen depot.
            cur_depot[self.start_new_routes] = selected[self.start_new_routes]

        else:
            # First action: it selects the starting depot for every instance.
            used_capacity = self.used_capacity
            cur_depot = selected

        if self.visited_.dtype == torch.uint8:
            # Note: no subtraction here — the mask keeps depot columns so the
            # raw index can be scattered directly (extra dim for a single value).
            visited_ = self.visited_.scatter(-1, prev_a[:, :, None], 1)
        else:
            # Packed bit-mask form: nothing is set when prev_a - 1 == -1 (depot).
            visited_ = mask_long_scatter(self.visited_, prev_a - 1)

        new_route_was_started = self.start_new_routes.detach().clone()
        start_new_routes = self.start_new_routes.detach().clone()
        # A depot selection schedules a route restart for the next step; rows
        # that were restarting this step are cleared again.
        start_new_routes[selected < self.num_depots] = True
        start_new_routes[self.start_new_routes] = False

        return self._replace(prev_a=prev_a,
                             used_capacity=used_capacity,
                             visited_=visited_,
                             lengths=lengths,
                             cur_coord=cur_coord,
                             cur_depot=cur_depot,
                             start_new_routes=start_new_routes,
                             new_route_was_started=new_route_was_started,
                             i=self.i + 1)