示例#1
0
def multi_main(args):
    """Launch one ``run`` worker process per GPU and block until all finish.

    ``args.world_size`` daemon processes are created from the ``spawn``
    multiprocessing context. Worker ``k`` is invoked as
    ``run(args, k, error_queue)``. Every worker PID is registered with an
    ``ErrorHandler`` bound to the shared error queue.
    """
    init_logger()

    spawn_ctx = torch.multiprocessing.get_context("spawn")

    # Children report failures through this queue; the handler is told
    # about every child PID so it can react to those reports.
    error_queue = spawn_ctx.SimpleQueue()
    error_handler = ErrorHandler(error_queue)

    workers = []
    for gpu_rank in range(args.world_size):
        worker = spawn_ctx.Process(
            target=run,
            args=(args, gpu_rank, error_queue),
            daemon=True,
        )
        workers.append(worker)
        worker.start()
        logger.info(" Starting process pid: %d  " % worker.pid)
        error_handler.add_child(worker.pid)

    for worker in workers:
        worker.join()
示例#2
0
def train_single_ext(args, device_id):
    """Train an ``ExtSummarizer`` on a single device.

    Args:
        args: parsed option namespace. Fields read here include ``log_file``,
            ``visible_gpus``, ``seed``, ``train_from``, ``batch_size`` and
            ``train_steps``. When ``train_from`` names a checkpoint, every key
            of the checkpoint's saved ``opt`` that appears in ``model_flags``
            is copied onto ``args`` before the model is built.
        device_id: CUDA device index to select; a negative value skips CUDA
            device selection and CUDA seeding.
    """
    init_logger(args.log_file)

    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)

    # Seed all RNGs once; reseeding with the same value afterwards would be a no-op.
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    checkpoint = None
    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        # map_location keeps every tensor on the storage it was deserialized
        # to (i.e. CPU) instead of the device it was saved from.
        # NOTE(review): newer PyTorch releases default torch.load to
        # weights_only=True, which may refuse the pickled namespace stored
        # under checkpoint['opt'] -- confirm against the installed version.
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        # Restore the options listed in model_flags (defined elsewhere) from
        # the checkpoint so the model below is built with the saved values.
        saved_opt = vars(checkpoint['opt'])
        for key in saved_opt:
            if key in model_flags:
                setattr(args, key, saved_opt[key])

    def train_iter_fct():
        # Fresh shuffled training iterator; the trainer calls this factory.
        return data_loader.Dataloader(args,
                                      load_dataset(args, 'train',
                                                   shuffle=True),
                                      args.batch_size,
                                      device,
                                      shuffle=True,
                                      is_test=False)

    model = ExtSummarizer(args, device, checkpoint)
    optim = model_builder.build_optim(args, model, checkpoint)

    logger.info(model)

    trainer = build_trainer(args, device_id, model, optim)
    trainer.train(train_iter_fct, args.train_steps)
示例#3
0
def train(args, device_id):
    """Train a ``Summarizer`` on a single device, optionally resuming.

    Args:
        args: parsed option namespace. Fields read here include ``log_file``,
            ``visible_gpus``, ``seed``, ``train_from``, ``batch_size`` and
            ``train_steps``. When ``train_from`` names a checkpoint, every key
            of the checkpoint's saved ``opt`` that appears in ``model_flags``
            is copied onto ``args`` *before* the model is constructed.
        device_id: CUDA device index to select; a negative value skips CUDA
            device selection and CUDA seeding.
    """
    init_logger(args.log_file)

    device = "cpu" if args.visible_gpus == "-1" else "cuda"
    logger.info("Device ID %d" % device_id)
    logger.info("Device %s" % device)

    # Seed all RNGs once; reseeding with the same value afterwards would be a no-op.
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    # Load the checkpoint and restore its model_flags onto args BEFORE the
    # model is built. Summarizer is constructed from args, so restoring the
    # flags only afterwards would leave the model configured from the
    # command-line values rather than the saved ones.
    checkpoint = None
    if args.train_from != "":
        logger.info("Loading checkpoint from %s" % args.train_from)
        # NOTE(review): newer PyTorch releases default torch.load to
        # weights_only=True, which may refuse the pickled namespace stored
        # under checkpoint["opt"] -- confirm against the installed version.
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        saved_opt = vars(checkpoint["opt"])
        for key in saved_opt:
            if key in model_flags:
                setattr(args, key, saved_opt[key])

    def train_iter_fct():
        # Fresh shuffled training iterator; the trainer calls this factory.
        return data_loader.Dataloader(args,
                                      load_dataset(args, "train",
                                                   shuffle=True),
                                      args.batch_size,
                                      device,
                                      shuffle=True,
                                      is_test=False)

    model = Summarizer(args, device, load_pretrained_bert=True)
    if checkpoint is not None:
        model.load_cp(checkpoint)
    # checkpoint is None when not resuming, matching the fresh-start call.
    optim = model_builder.build_optim(args, model, checkpoint)

    trainer = build_trainer(args, device_id, model, optim)
    trainer.train(train_iter_fct, args.train_steps)
示例#4
0
    # Dataset name used to pick the JSON input files; e.g. for
    # json_data/LCSTS.train.1.json, pass LCSTS.
    parser.add_argument('-dataset', default='LCSTS', type=str)
    # Where the preprocessed model-input data for training is saved.
    parser.add_argument("-save_path", default='bert_data')

    # Shard size; default changed from 2000 to 16000.
    parser.add_argument("-shard_size", default=16000, type=int)
    # Minimum number of sentences per article (articles must have at least
    # this many; the filtering itself lives in data_builder_LAI, not shown).
    parser.add_argument('-min_nsents', default=3, type=int)
    # Maximum number of sentences per article; how longer articles are
    # handled is decided in data_builder_LAI -- TODO confirm.
    parser.add_argument('-max_nsents', default=100, type=int)

    # Minimum sentence length (presumably in tokens, per the option name).
    parser.add_argument('-min_src_ntokens', default=3, type=int)
    # Maximum sentence length (presumably in tokens, per the option name).
    parser.add_argument('-max_src_ntokens', default=150, type=int)

    parser.add_argument('-max_position_embeddings', default=512, type=int)
    parser.add_argument('-log_file', default='logs/preprocess.log')

    parser.add_argument('-n_cpus', default=4, type=int)

    # NOTE(review): machine-specific absolute default path; on any other
    # machine -bert_base_chinese must be passed explicitly.
    bert_base_chinese = '/Users/jiang/Documents/bert/bert-base-chinese'
    parser.add_argument("-bert_base_chinese",
                        type=str,
                        default=bert_base_chinese)

    # Parse options, set up logging, then hand off to the BERT-format builder.
    args = parser.parse_args()
    init_logger(args.log_file)
    data_builder_LAI.format_to_bert(args)