Code Example #1
import math
import os
import sys

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# config, utils, Text, Data, NMT, Optim and evaluate_ppl are project-local
# modules/classes from the surrounding repository (not shown on this page).

def train():
    torch.manual_seed(1)
    if config.cuda:
        torch.cuda.manual_seed(1)
    args = {
        'embed_size': config.embed_size,
        'd_model': config.d_model,
        'nhead': config.nhead,
        'num_encoder_layers': config.num_encoder_layers,
        'num_decoder_layers': config.num_decoder_layers,
        'dim_feedforward': config.dim_feedforward,
        'dropout': config.dropout,
        'smoothing_eps': config.smoothing_eps,
    }
    text = Text(config.src_corpus, config.tar_corpus)
    train_data = Data(config.train_path_src, config.train_path_tar)
    dev_data = Data(config.dev_path_src, config.dev_path_tar)
    train_loader = DataLoader(dataset=train_data,
                              batch_size=config.train_batch_size,
                              shuffle=True,
                              collate_fn=utils.get_batch)
    dev_loader = DataLoader(dataset=dev_data,
                            batch_size=config.dev_batch_size,
                            shuffle=True,
                            collate_fn=utils.get_batch)
    device = torch.device("cuda:0" if config.cuda else "cpu")
    model = NMT(text, args, device)
    model = model.to(device)
    # to resume from a saved checkpoint instead:
    # model = NMT.load(model_path).to(device)
    model.train()
    # Adam wrapped in a warmup learning-rate scheduler (see the Optim sketch below)
    optimizer = Optim(
        torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9),
        config.d_model, config.warm_up_step)

    epoch = 0
    history_valid_ppl = []
    print("begin training!", file=sys.stderr)
    while True:
        epoch += 1
        max_iter = math.ceil(len(train_data) / config.train_batch_size)
        with tqdm(total=max_iter, desc="train") as pbar:
            for batch_src, batch_tar, tar_word_num in train_loader:
                optimizer.zero_grad()
                now_batch_size = len(batch_src)
                # the model returns per-sentence log-likelihoods; negate and sum for the batch loss
                batch_loss = -model(batch_src, batch_tar, smoothing=True)
                batch_loss = batch_loss.sum()
                loss = batch_loss / now_batch_size
                loss.backward()
                optimizer.step_and_updata_lr()
                pbar.set_postfix({
                    "epoch": epoch,
                    "avg_loss": "%.2f" % loss.item(),
                    "ppl": "%.2f" % math.exp(batch_loss.item() / tar_word_num)
                })
                pbar.update(1)
        if epoch % config.valid_iter == 0:
            print("now begin validation...", file=sys.stderr)
            eval_ppl = evaluate_ppl(model, dev_data, dev_loader,
                                    config.dev_batch_size)
            print("validation ppl %.2f" % eval_ppl, file=sys.stderr)
            is_best = len(history_valid_ppl) == 0 or eval_ppl < min(history_valid_ppl)
            if is_best:
                print(
                    f"current model is the best! save to [{config.model_save_path}]",
                    file=sys.stderr)
                history_valid_ppl.append(eval_ppl)
                model.save(
                    os.path.join(config.model_save_path,
                                 f"02.10_{epoch}_{eval_ppl}_checkpoint.pth"))
                torch.save(
                    optimizer.optimizer.state_dict(),
                    os.path.join(config.model_save_path,
                                 f"02.10_{epoch}_{eval_ppl}_optimizer.optim"))
        if epoch == config.max_epoch:
            print("reached the maximum number of epochs!", file=sys.stderr)
            return
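
The `Optim` wrapper used above is not shown on this page. Judging from the call sites (`Optim(torch.optim.Adam(...), config.d_model, config.warm_up_step)` and `optimizer.step_and_updata_lr()`), it appears to implement the inverse-square-root warmup schedule from the Transformer paper. The following is a minimal sketch under that assumption, not the project's actual class:

class Optim:
    """Sketch of a warmup LR wrapper: lr = d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)."""

    def __init__(self, optimizer, d_model=None, warm_up_step=None):
        self.optimizer = optimizer
        self.d_model = d_model
        self.warm_up_step = warm_up_step
        self.step_num = 0
        self.lr = optimizer.param_groups[0]['lr']

    def zero_grad(self):
        self.optimizer.zero_grad()

    def step_and_updata_lr(self):  # spelling kept to match the call sites
        self.step_num += 1
        self.lr = (self.d_model ** -0.5) * min(self.step_num ** -0.5,
                                               self.step_num * self.warm_up_step ** -1.5)
        for group in self.optimizer.param_groups:
            group['lr'] = self.lr
        self.optimizer.step()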
Code Example #2
import math
import os
import sys
from optparse import OptionParser

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# config, utils, Text, Data, NMT, Optim and evaluate_ppl are project-local
# modules/classes from the surrounding repository (not shown on this page).

def train():
    text = Text(config.src_corpus, config.tar_corpus)
    train_data = Data(config.train_path_src, config.train_path_tar)
    dev_data = Data(config.dev_path_src, config.dev_path_tar)
    train_loader = DataLoader(dataset=train_data,
                              batch_size=config.batch_size,
                              shuffle=True,
                              collate_fn=utils.get_batch)
    dev_loader = DataLoader(dataset=dev_data,
                            batch_size=config.dev_batch_size,
                            shuffle=True,
                            collate_fn=utils.get_batch)
    parser = OptionParser()
    parser.add_option("--embed_size",
                      dest="embed_size",
                      default=config.embed_size)
    parser.add_option("--hidden_size",
                      dest="hidden_size",
                      default=config.hidden_size)
    parser.add_option("--window_size_d",
                      dest="window_size_d",
                      default=config.window_size_d)
    parser.add_option("--encoder_layer",
                      dest="encoder_layer",
                      default=config.encoder_layer)
    parser.add_option("--decoder_layers",
                      dest="decoder_layers",
                      default=config.decoder_layers)
    parser.add_option("--dropout_rate",
                      dest="dropout_rate",
                      default=config.dropout_rate)
    (options, args) = parser.parse_args()
    device = torch.device("cuda:0" if config.cuda else "cpu")
    model = NMT(text, options, device)
    # to resume from a saved checkpoint instead:
    # model = NMT.load(model_path)
    model = model.to(device)  # .to(device) already handles GPU placement; a further .cuda() would crash CPU-only runs
    model.train()
    optimizer = Optim(torch.optim.Adam(model.parameters()))
    epoch = 0
    valid_num = 1
    hist_valid_ppl = []

    print("begin training!")
    while (True):
        epoch += 1
        max_iter = int(math.ceil(len(train_data) / config.batch_size))
        with tqdm(total=max_iter, desc="train") as pbar:
            for src_sents, tar_sents, tar_words_num_to_predict in train_loader:
                optimizer.zero_grad()
                batch_size = len(src_sents)

                now_loss = -model(src_sents, tar_sents)
                now_loss = now_loss.sum()
                loss = now_loss / batch_size
                loss.backward()

                # clip gradients to stabilize training
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               config.clip_grad)
                optimizer.step_and_updata_lr()

                pbar.set_postfix({
                    "epoch": epoch,
                    "avg_loss": "%.2f" % loss.item(),
                    "ppl": "%.2f" % math.exp(now_loss.item() / tar_words_num_to_predict),
                    "lr": optimizer.lr
                })
                pbar.update(1)
        if epoch % config.valid_iter == 0:
            # adjust the learning rate every fifth validation round (Optim's updata_lr)
            if valid_num % 5 == 0:
                valid_num = 0
                optimizer.updata_lr()
            valid_num += 1
            print("now begin validation ...", file=sys.stderr)
            eval_ppl = evaluate_ppl(model, dev_data, dev_loader)
            print("validation ppl %.2f" % eval_ppl, file=sys.stderr)
            is_best = len(hist_valid_ppl) == 0 or eval_ppl < min(hist_valid_ppl)
            if is_best:
                print("current model is the best! save to [%s]" %
                      config.model_save_path,
                      file=sys.stderr)
                hist_valid_ppl.append(eval_ppl)
                model.save(
                    os.path.join(
                        config.model_save_path,
                        f"02.08_window35drop0.2_{epoch}_{eval_ppl}_checkpoint.pth"
                    ))
                torch.save(
                    optimizer.optimizer.state_dict(),
                    os.path.join(
                        config.model_save_path,
                        f"02.08_window35drop0.2_{epoch}_{eval_ppl}_optimizer.optim"
                    ))
        if epoch == config.max_epoch:
            print("reached the maximum number of epochs!", file=sys.stderr)
            return
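
`evaluate_ppl` is likewise not shown. Below is a plausible minimal version, assuming the model returns per-sentence log-likelihoods exactly as in the training loops above; the extra `batch_size` parameter mirrors the call in Code Example #1 and is unused here:

import math
import torch

def evaluate_ppl(model, dev_data, dev_loader, batch_size=None):
    """Corpus-level perplexity: exp(total negative log-likelihood / total target words)."""
    was_training = model.training
    model.eval()
    total_loss, total_words = 0.0, 0
    with torch.no_grad():
        for batch_src, batch_tar, tar_word_num in dev_loader:
            batch_loss = -model(batch_src, batch_tar)  # negated log-likelihoods
            total_loss += batch_loss.sum().item()
            total_words += tar_word_num
    if was_training:
        model.train()
    return math.exp(total_loss / total_words)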
Code Example #3
File: train.py  Project: Crazy-Chick/ShuHeLearning
import math
import os
import sys

import torch
import torch.distributed
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

# config, utils, Text, Data, NMT, Optim, make_data_parallel and evaluate_ppl
# are project-local modules/classes from the surrounding repository (not shown on this page).

def train(index):
    torch.manual_seed(1)
    if config.cuda:
        torch.cuda.manual_seed(1)
    device = torch.device(f"cuda:{index}" if config.cuda else "cpu")
    dist_rank = index
    # world_size must match the number of spawned training processes (see the launcher sketch below)
    torch.distributed.init_process_group(backend='nccl', init_method='tcp://localhost:23456', rank=dist_rank, world_size=1)
    is_master_node = (dist_rank == 0)
    
    args = {
        'embed_size': config.embed_size,
        'd_model': config.d_model,
        'nhead': config.nhead,
        'num_encoder_layers': config.num_encoder_layers,
        'num_decoder_layers': config.num_decoder_layers,
        'dim_feedforward': config.dim_feedforward,
        'dropout': config.dropout,
        'smoothing_eps': config.smoothing_eps,
    }
    
    text = Text(config.src_corpus, config.tar_corpus)
    model = NMT(text, args, device)
    model = make_data_parallel(model, device)
    
    train_data = Data(config.train_path_src, config.train_path_tar)
    dev_data = Data(config.dev_path_src, config.dev_path_tar)
    train_sampler = DistributedSampler(train_data)
    dev_sampler = DistributedSampler(dev_data)
    # per-process batch size: the global batch size is divided across the distributed workers
    train_loader = DataLoader(dataset=train_data, batch_size=config.train_batch_size // 8, shuffle=False, num_workers=9, pin_memory=True, sampler=train_sampler, collate_fn=utils.get_batch)
    dev_loader = DataLoader(dataset=dev_data, batch_size=config.dev_batch_size // 8, shuffle=False, num_workers=9, pin_memory=True, sampler=dev_sampler, collate_fn=utils.get_batch)

    model.train()
    optimizer = Optim(torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9), config.d_model, config.warm_up_step)

    epoch = 0
    history_valid_ppl = []
    print("begin training!", file=sys.stderr)
    while True:
        epoch += 1
        train_loader.sampler.set_epoch(epoch)
        max_iter = int(math.ceil(len(train_data)/config.train_batch_size))
        with tqdm(total=max_iter, desc="train") as pbar:
            for batch_src, batch_tar, tar_word_num in train_loader:
                optimizer.zero_grad()
                now_batch_size = len(batch_src)
                batch_loss = -model(batch_src, batch_tar, smoothing=True)
                batch_loss = batch_loss.sum()
                loss = batch_loss / now_batch_size
                loss.backward()
                torch.distributed.barrier()  # synchronize workers before the optimizer step
                optimizer.step_and_updata_lr()
                if is_master_node:
                    pbar.set_postfix({"epoch": epoch, "avg_loss": "%.2f" % loss.item(), "ppl": "%.2f" % math.exp(batch_loss.item() / tar_word_num)})
                    pbar.update(1)
        if epoch % config.valid_iter == 0:
            print("now begin validation...", file=sys.stderr)
            torch.distributed.barrier()
            eval_ppl = evaluate_ppl(model, dev_data, dev_loader, config.dev_batch_size, is_master_node)
            print("validation ppl %.2f" % eval_ppl, file=sys.stderr)
            is_best = len(history_valid_ppl) == 0 or eval_ppl < min(history_valid_ppl)
            if is_best:
                print(f"current model is the best! save to [{config.model_save_path}]", file=sys.stderr)
                history_valid_ppl.append(eval_ppl)
                model.save(os.path.join(config.model_save_path, f"02.19_{epoch}_{eval_ppl}_checkpoint.pth"))
                torch.save(optimizer.optimizer.state_dict(), os.path.join(config.model_save_path, f"02.19_{epoch}_{eval_ppl}_optimizer.optim"))
        if epoch == config.max_epoch:
            print("reached the maximum number of epochs!", file=sys.stderr)
            return
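
Code Example #3 takes a process index, so it is meant to be launched once per GPU. A hypothetical launcher using torch.multiprocessing.spawn is sketched below; it is not part of the original file, and world_size in init_process_group would then have to equal n_gpus rather than 1:

import torch
import torch.multiprocessing as mp

if __name__ == "__main__":
    n_gpus = torch.cuda.device_count()
    # spawn calls train(i) for i in range(n_gpus), one process per GPU
    mp.spawn(train, nprocs=n_gpus)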