def init_process(local_rank, backend, config):
    """Entry point for one distributed-training worker process.

    Sets up the torch.distributed process group, builds the vocab / DB /
    reader / model / optimizer, then runs the train-validate loop with
    early stopping and LR halving. Only rank 0 owns the TensorBoard
    writer and writes checkpoints.

    Args:
        local_rank: this process's rank (== GPU index).
        backend: torch.distributed backend name (e.g. "nccl").
        config: parsed experiment configuration.
    """
    # Single-node rendezvous on a fixed local port.
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group(backend, rank=local_rank, world_size=config.num_gpus)
    torch.cuda.set_device(local_rank)
    torch.backends.cudnn.benchmark = True

    logger = logging.getLogger("DST")
    logger.setLevel(logging.INFO)
    stream_handler = logging.StreamHandler()
    logger.addHandler(stream_handler)
    if local_rank != 0:
        # Non-zero ranks stay quiet so the console shows one copy of each message.
        logger.setLevel(logging.WARNING)
    if local_rank == 0:
        writer = SummaryWriter()
    # Fix: exists()+mkdir() was a TOCTOU race when several ranks start at once;
    # makedirs(exist_ok=True) is atomic with respect to that race.
    os.makedirs("save", exist_ok=True)
    # Fix: raw string for the regex ("\s" is an invalid escape in a plain string).
    save_path = "save/model_{}.pt".format(re.sub(r"\s+", "_", time.asctime()))

    random.seed(config.seed)
    vocab = Vocab(config)
    vocab.load("save/vocab")
    db = DB(config.data_path)
    reader = Reader(vocab, config)

    start = time.time()
    logger.info("Loading data...")
    reader.load_data("train")
    end = time.time()
    logger.info("Loaded. {} secs".format(end - start))

    evaluator = MultiWOZEvaluator(reader, db, config)
    lr = config.lr
    model = DIALOG(vocab, db, config).cuda()
    optimizer = Adam(model.parameters(), lr=lr)
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[local_rank], output_device=local_rank)

    # Resume model + optimizer state if a checkpoint path was given.
    if config.save_path is not None:
        load(model, optimizer, config.save_path)

    # Attach iteration counts / warmup schedule as attributes read by train()/validate().
    train.max_iter = len(list(reader.make_batch(reader.train)))
    validate.max_iter = len(list(reader.make_batch(reader.dev)))
    train.warmup_steps = train.max_iter * config.max_epochs * config.warmup_steps
    train.global_step = 0

    max_joint_acc = 0
    early_stop_count = config.early_stop_count

    # Baseline validation before any training.
    loss, joint_acc, slot_acc = validate(model, reader, evaluator, config, local_rank)
    logger.info(
        "loss: {:.4f}, joint accuracy: {:.4f}, slot accuracy: {:.4f}".format(
            loss, joint_acc, slot_acc))

    for epoch in range(config.max_epochs):
        logger.info("Train...")
        start = time.time()
        if local_rank == 0:
            train(model, reader, optimizer, writer, config, local_rank, evaluator)
        else:
            # Only rank 0 has a SummaryWriter; others pass None.
            train(model, reader, optimizer, None, config, local_rank, evaluator)
        end = time.time()
        logger.info("epoch: {}, {:.4f} secs".format(epoch + 1, end - start))

        logger.info("Validate...")
        loss, joint_acc, slot_acc = validate(model, reader, evaluator, config, local_rank)
        logger.info(
            "loss: {:.4f}, joint accuracy: {:.4f}, slot accuracy: {:.4f}".
            format(loss, joint_acc, slot_acc))

        if local_rank == 0:
            writer.add_scalar("Val/loss", loss, epoch + 1)
            writer.add_scalar("Val/joint_acc", joint_acc, epoch + 1)
            writer.add_scalar("Val/slot_acc", slot_acc, epoch + 1)

        if joint_acc > max_joint_acc:
            # New best joint accuracy: checkpoint (rank 0) and reset patience.
            if local_rank == 0:
                save(model, optimizer, save_path)
                logger.info("Saved to {}.".format(os.path.abspath(save_path)))
            max_joint_acc = joint_acc
            early_stop_count = config.early_stop_count
        else:  # early stopping
            if early_stop_count == 0:
                logger.info("Early stopped.")
                break
            elif early_stop_count == 2:
                # Halve the learning rate once, partway through the patience window.
                lr = lr / 2
                logger.info("learning rate schedule: {}".format(lr))
                for param in optimizer.param_groups:
                    param["lr"] = lr
            early_stop_count -= 1
            logger.info("early stop count: {}".format(early_stop_count))
    logger.info("Training finished.")
inform_rate_ = inform_rate.float().mean(dim=1).sum(dim=0) inform_rate = (inform_rate.float().mean(dim=1) == 1) inform_rate = inform_rate.sum(dim=0).float() success_rate = (true_request_goals <= pred_request_goals ) # like recall (not precision) success_rate_ = success_rate.float().mean(dim=1).sum(dim=0) success_rate = (success_rate.float().mean(dim=1) == 1 ) # check all slots success_rate = success_rate.sum(dim=0).float() return inform_rate, success_rate, inform_rate_, success_rate_ if __name__ == "__main__": config = Config() parser = config.parser config = parser.parse_args() vocab = Vocab(config) vocab.load("save/vocab") reader = Reader(vocab, config) evaluator = MultiWOZEvaluator(reader, db, config) dial_id = evaluator.test_list[0] goal = evaluator.parse_goal(dial_id) print("")