コード例 #1
0
def train_tsp(args, w1=1, w2=0, checkpoint=None):
    """Train (unless ``args.test`` is set) and then evaluate the TSP actor.

    Builds the train/validation datasets, the actor and the critic,
    optionally restores both networks from ``checkpoint``, runs ``train``
    when ``args.test`` is falsy, and finally runs ``validate`` on a freshly
    generated test set and prints the result.

    Args:
        args: Namespace of run settings. This function reads ``num_nodes``,
            ``train_size``, ``valid_size``, ``seed``, ``hidden_size``,
            ``num_layers``, ``dropout`` and ``test``; a copy of all its
            attributes is also forwarded to ``train`` as keyword arguments.
        w1: First weight passed through to ``train`` and ``validate``
            (presumably the weight of the first objective -- confirm
            against ``tasks.motsp.reward``).
        w2: Second weight passed through to ``train`` and ``validate``.
        checkpoint: Optional directory holding ``actor.pt`` and
            ``critic.pt``. When truthy, both state dicts are loaded before
            any training or testing happens.

    Returns:
        None. The value returned by ``validate`` is printed.
    """

    # Goals from paper:
    # TSP20, 3.97
    # TSP50, 6.08
    # TSP100, 8.44

    from tasks import motsp
    from tasks.motsp import TSPDataset

    # NOTE(review): the original comment here said "(x, y)", which does not
    # match the value 4 -- presumably two coordinate pairs per node for the
    # multi-objective task; confirm against TSPDataset.
    STATIC_SIZE = 4
    DYNAMIC_SIZE = 1  # dummy for compatibility

    # Distinct seeds (seed, seed + 1, seed + 2) are used for the train,
    # validation and test datasets.
    train_data = TSPDataset(args.num_nodes, args.train_size, args.seed)
    valid_data = TSPDataset(args.num_nodes, args.valid_size, args.seed + 1)

    update_fn = None

    actor = DRL4TSP(STATIC_SIZE, DYNAMIC_SIZE, args.hidden_size, update_fn,
                    motsp.update_mask, args.num_layers,
                    args.dropout).to(device)

    critic = StateCritic(STATIC_SIZE, DYNAMIC_SIZE,
                         args.hidden_size).to(device)

    # vars(args) returns args.__dict__ itself, so writing into it would attach
    # the datasets and callbacks to the caller's namespace. Work on a copy.
    kwargs = dict(vars(args))
    kwargs.update(
        train_data=train_data,
        valid_data=valid_data,
        reward_fn=motsp.reward,
        render_fn=motsp.render,
    )

    if checkpoint:
        actor_path = os.path.join(checkpoint, 'actor.pt')
        actor.load_state_dict(torch.load(actor_path, map_location=device))
        # actor.static_encoder.state_dict().get("conv.weight").size()
        critic_path = os.path.join(checkpoint, 'critic.pt')
        critic.load_state_dict(torch.load(critic_path, map_location=device))

    if not args.test:
        train(actor, critic, w1, w2, **kwargs)

    test_data = TSPDataset(args.num_nodes, args.valid_size, args.seed + 2)

    test_dir = 'test'
    # The batch size equals the dataset size requested above, and shuffling
    # is disabled.
    test_loader = DataLoader(test_data, args.valid_size, False, num_workers=0)
    out = validate(test_loader,
                   actor,
                   motsp.reward,
                   w1,
                   w2,
                   motsp.render,
                   test_dir,
                   num_plot=5)

    print('w1=%2.2f,w2=%2.2f. Average tour length: ' % (w1, w2), out)
コード例 #2
0
                1,
                0.1).to(device)
critic = StateCritic(STATIC_SIZE, DYNAMIC_SIZE, 128).to(device)

# data 143
from Post_process.convet_kro_dataloader import Kro_dataset
kro = 1  # truthy -> evaluate on the Kro dataset, falsy -> on a TSPDataset
D = 200
if kro:
    Test_data = Kro_dataset(D)
else:
    # 40city_train: city20 13 city40 143 city70 2523
    # NOTE(review): presumably these are the seeds used per instance size;
    # confirm against the training scripts.
    Test_data = TSPDataset(D, 1, 2523)
# Both branches use the same loader settings: batch size 1, no shuffling.
Test_loader = DataLoader(Test_data, 1, False, num_workers=0)

iter_data = iter(Test_loader)
# Use the built-in next(): the Python 2 style ``.next()`` method is not
# available on current DataLoader iterators.
static, dynamic, x0 = next(iter_data)
static = static.to(device)
dynamic = dynamic.to(device)
x0 = x0.to(device) if len(x0) > 0 else None

# N + 1 evenly spaced weights in [0, 1].
# NOTE(review): the original comment said "load 50 models", which does not
# match N = 100 (101 weights) -- confirm the intended count.
N = 100
w = np.arange(N + 1) / N
# One row per weight, two columns (presumably the two objective values).
objs = np.zeros((N + 1, 2))
start = time.time()
t1_all = 0
t2_all = 0
コード例 #3
0
ファイル: test.py プロジェクト: echofist/RL_TSP_4static
from tasks.motsp import TSPDataset, reward
from torch.utils.data import DataLoader
import torch

# Smoke test: build a small dataset, draw one shuffled batch of 100 samples,
# and print the reward of a single random permutation with weights (1, 0).
train_data = TSPDataset(10, 10000, 1234)
train_loader = DataLoader(train_data, 100, True, num_workers=0)
iter_data = iter(train_loader)
# Use the built-in next(): the Python 2 style ``.next()`` method is not
# available on current DataLoader iterators. ``[0]`` keeps only the first
# element of the batch (presumably the static node features -- confirm
# against TSPDataset.__getitem__).
batch = next(iter_data)[0]
print(reward(batch, torch.randperm(10).expand(1, 10), 1, 0))