def train_tsp(args, w1=1, w2=0, checkpoint=None):
    # Goals from paper:
    # TSP20, 3.97
    # TSP50, 6.08
    # TSP100, 8.44

    from tasks import motsp
    from tasks.motsp import TSPDataset

    STATIC_SIZE = 4   # two (x, y) coordinate sets, one per objective
    DYNAMIC_SIZE = 1  # dummy for compatibility

    train_data = TSPDataset(args.num_nodes, args.train_size, args.seed)
    valid_data = TSPDataset(args.num_nodes, args.valid_size, args.seed + 1)

    update_fn = None

    actor = DRL4TSP(STATIC_SIZE,
                    DYNAMIC_SIZE,
                    args.hidden_size,
                    update_fn,
                    motsp.update_mask,
                    args.num_layers,
                    args.dropout).to(device)

    critic = StateCritic(STATIC_SIZE, DYNAMIC_SIZE, args.hidden_size).to(device)

    kwargs = vars(args)
    kwargs['train_data'] = train_data
    kwargs['valid_data'] = valid_data
    kwargs['reward_fn'] = motsp.reward
    kwargs['render_fn'] = motsp.render

    if checkpoint:
        path = os.path.join(checkpoint, 'actor.pt')
        actor.load_state_dict(torch.load(path, device))

        path = os.path.join(checkpoint, 'critic.pt')
        critic.load_state_dict(torch.load(path, device))

    if not args.test:
        train(actor, critic, w1, w2, **kwargs)

    test_data = TSPDataset(args.num_nodes, args.valid_size, args.seed + 2)

    test_dir = 'test'
    test_loader = DataLoader(test_data, args.valid_size, False, num_workers=0)
    out = validate(test_loader, actor, motsp.reward, w1, w2, motsp.render,
                   test_dir, num_plot=5)

    print('w1=%2.2f,w2=%2.2f. Average tour length: ' % (w1, w2), out)
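
# Hedged usage sketch (not part of the original file): train_tsp() only needs an
# object exposing the attributes it reads above (num_nodes, train_size,
# valid_size, seed, hidden_size, num_layers, dropout, test). The real train()
# call may expect additional fields (batch size, learning rates, ...), so the
# Namespace below is illustrative only and the checkpoint path is hypothetical.
if __name__ == '__main__':
    import argparse

    demo_args = argparse.Namespace(
        num_nodes=20, train_size=120000, valid_size=1000, seed=1234,
        hidden_size=128, num_layers=1, dropout=0.1, test=False)

    # Train one weighted-sum subproblem (w1, w2) of the bi-objective TSP.
    train_tsp(demo_args, w1=0.5, w2=0.5)

    # Or resume from a previously saved actor/critic pair (set test=True to
    # skip training and only run validation):
    # train_tsp(demo_args, w1=1, w2=0, checkpoint='checkpoints/tsp20_w1')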
# Actor/critic for inference. The actor arguments mirror train_tsp() above;
# hidden_size is assumed to be 128 (matching the critic), with num_layers=1
# and dropout=0.1.
actor = DRL4TSP(STATIC_SIZE, DYNAMIC_SIZE, 128, None, motsp.update_mask,
                1, 0.1).to(device)
critic = StateCritic(STATIC_SIZE, DYNAMIC_SIZE, 128).to(device)

# data 143
from Post_process.convet_kro_dataloader import Kro_dataset

kro = 1
D = 200
if kro:
    D = 200
    Test_data = Kro_dataset(D)
    Test_loader = DataLoader(Test_data, 1, False, num_workers=0)
else:
    # 40city_train seeds: city20 13, city40 143, city70 2523
    Test_data = TSPDataset(D, 1, 2523)
    Test_loader = DataLoader(Test_data, 1, False, num_workers=0)

iter_data = iter(Test_loader)
static, dynamic, x0 = next(iter_data)   # Python 3: next(), not .next()

static = static.to(device)
dynamic = dynamic.to(device)
x0 = x0.to(device) if len(x0) > 0 else None

# load 50 models
N = 100
w = np.arange(N + 1) / N        # N + 1 evenly spaced weights in [0, 1]
objs = np.zeros((N + 1, 2))     # (obj1, obj2) recorded for each weight

start = time.time()
t1_all = 0
t2_all = 0
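
# Hedged sketch (assumption, not the original evaluation loop): each w[i]
# defines one weighted-sum subproblem f_i = w[i]*obj1 + (1 - w[i])*obj2 over the
# two tour-length objectives, and objs[i] is meant to hold the (obj1, obj2) pair
# produced by the model assigned to that weight. Once objs is filled, the
# scalarized values and the non-dominated front can be recovered with plain
# numpy as below; the loop that actually fills objs is omitted here.
import numpy as np

def scalarize(objs, w):
    """Weighted-sum value of each (obj1, obj2) row under its own weight w[i]."""
    return w * objs[:, 0] + (1.0 - w) * objs[:, 1]

def nondominated(objs):
    """Boolean mask of rows not dominated by any other row (minimisation)."""
    keep = np.ones(len(objs), dtype=bool)
    for i, p in enumerate(objs):
        dominates_p = np.all(objs <= p, axis=1) & np.any(objs < p, axis=1)
        keep[i] = not dominates_p.any()
    return keep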
from tasks.motsp import TSPDataset, reward
from torch.utils.data import DataLoader
import torch

# Quick sanity check: score a random (but fixed) tour on a batch of instances.
train_data = TSPDataset(10, 10000, 1234)
train_loader = DataLoader(train_data, 100, True, num_workers=0)

iter_data = iter(train_loader)
batch = next(iter_data)[0]      # static coordinates of the first batch

# Reuse the same random permutation for every instance so the index tensor
# matches the batch dimension expected by reward().
tours = torch.randperm(10).expand(batch.size(0), 10)
print(reward(batch, tours, 1, 0))
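
# Hedged follow-up check (assumes motsp.reward is a linear weighted sum of the
# two tour lengths, which its (w1, w2) arguments suggest but the snippet above
# does not prove): a mixed-weight reward should then equal the same mix of the
# two single-objective rewards.
r1 = reward(batch, tours, 1, 0)        # objective 1 only
r2 = reward(batch, tours, 0, 1)        # objective 2 only
r_mix = reward(batch, tours, 0.5, 0.5)
print(torch.allclose(r_mix, 0.5 * r1 + 0.5 * r2))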