pi: (batch, city_t), predicted tour return: (batch) """ log_p = torch.gather(input=_log_p, dim=2, index=pi[:, :, None]) return torch.sum(log_p.squeeze(-1), 1) if __name__ == '__main__': cfg = load_pkl(pkl_parser().path) model = PtrNet1(cfg) inputs = torch.randn(3, 20, 2) pi, ll = model(inputs, device='cpu') print('pi:', pi.size(), pi) print('log_likelihood:', ll.size(), ll) cnt = 0 for i, k in model.state_dict().items(): print(i, k.size(), torch.numel(k)) cnt += torch.numel(k) print('total parameters:', cnt) # ll.mean().backward() # print(model.W_q.weight.grad) cfg.batch = 3 env = Env_tsp(cfg) cost = env.stack_l(inputs, pi) print('cost:', cost.size(), cost) cost = env.stack_l_fast(inputs, pi) print('cost:', cost.size(), cost)
if (ave_L / (i + 1) < min_L): min_L = ave_L / (i + 1) else: cnt += 1 print(f'cnt: {cnt}/20') if (cnt >= 20): print('early stop, average cost cant decrease anymore') if log_path is not None: with open(log_path, 'a') as f: f.write('\nearly stop') break t1 = time() if cfg.issaver: torch.save(act_model.state_dict(), cfg.model_dir + '%s_%s_step%d_act.pt' % (cfg.task, date, i)) # 'cfg.model_dir = ./Pt/' print('save model...') if __name__ == '__main__': cfg = load_pkl(pkl_parser().path) env = Env_tsp(cfg) if cfg.mode in ['train', 'train_emv']: # train_emv --> exponential moving average, not use critic model train_model(cfg, env) else: raise NotImplementedError( 'train and train_emv only, specify train pkl file')