def multi_main(args):
    """Spawn one training process per GPU and block until all of them exit.

    Children run ``run(args, device_id, error_queue)``; any exception they
    raise is forwarded through ``error_queue`` to the ``ErrorHandler`` thread.
    """
    init_logger()
    ctx = torch.multiprocessing.get_context("spawn")

    # Queue + handler thread that surfaces errors raised in child processes.
    error_queue = ctx.SimpleQueue()
    error_handler = ErrorHandler(error_queue)

    # Launch one daemon worker per GPU (world_size of them).
    workers = []
    for device_id in range(args.world_size):
        proc = ctx.Process(
            target=run,
            args=(args, device_id, error_queue),
            daemon=True,
        )
        proc.start()
        logger.info(" Starting process pid: %d " % proc.pid)
        error_handler.add_child(proc.pid)
        workers.append(proc)

    # Wait for every worker to finish before returning.
    for proc in workers:
        proc.join()
def train_single_ext(args, device_id):
    """Train the extractive summarizer (`ExtSummarizer`) on a single device.

    Args:
        args: parsed command-line namespace (seed, visible_gpus, train_from,
            batch_size, train_steps, log_file, ...).
        device_id: GPU ordinal, or a negative value for CPU-only training.

    Side effects: seeds all RNGs, optionally loads a checkpoint from
    ``args.train_from`` and copies its model-structure flags onto ``args``,
    then runs the training loop for ``args.train_steps`` steps.
    """
    init_logger(args.log_file)
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)

    # Seed every RNG once for reproducibility.  (Bug fix: the original
    # repeated torch.manual_seed / random.seed / cudnn.deterministic a second
    # time inside the CUDA branch — redundant, since no random draw happens
    # in between.)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        # map_location keeps tensors on CPU so loading works without a GPU.
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint['opt'])
        # Restore model-structure flags so the rebuilt model matches the
        # architecture that was saved.
        for k in opt:
            if k in model_flags:
                setattr(args, k, opt[k])
    else:
        checkpoint = None

    def train_iter_fct():
        # Fresh training-data iterator for each epoch/pass.
        return data_loader.Dataloader(args, load_dataset(args, 'train', shuffle=True),
                                      args.batch_size, device,
                                      shuffle=True, is_test=False)

    model = ExtSummarizer(args, device, checkpoint)
    optim = model_builder.build_optim(args, model, checkpoint)
    logger.info(model)

    trainer = build_trainer(args, device_id, model, optim)
    trainer.train(train_iter_fct, args.train_steps)
def train(args, device_id):
    """Train the BERT-based `Summarizer` on a single device.

    Args:
        args: parsed command-line namespace (seed, visible_gpus, train_from,
            batch_size, train_steps, log_file, ...).
        device_id: GPU ordinal, or a negative value for CPU-only training.

    Side effects: seeds all RNGs, optionally resumes from the checkpoint at
    ``args.train_from``, then runs the training loop for ``args.train_steps``
    steps.
    """
    init_logger(args.log_file)
    device = "cpu" if args.visible_gpus == "-1" else "cuda"
    logger.info("Device ID %d" % device_id)
    logger.info("Device %s" % device)

    # Seed every RNG once for reproducibility.  (Bug fix: the original
    # repeated torch.manual_seed / random.seed / cudnn.deterministic a second
    # time inside the CUDA branch — redundant, since no random draw happens
    # in between.)
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    def train_iter_fct():
        # Fresh training-data iterator for each epoch/pass.
        return data_loader.Dataloader(args, load_dataset(args, "train", shuffle=True),
                                      args.batch_size, device,
                                      shuffle=True, is_test=False)

    checkpoint = None
    if args.train_from != "":
        logger.info("Loading checkpoint from %s" % args.train_from)
        # map_location keeps tensors on CPU so loading works without a GPU.
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        opt = vars(checkpoint["opt"])
        # Apply the checkpoint's model-structure flags to ``args`` BEFORE the
        # model is built.  (Bug fix: the original constructed Summarizer first
        # and only then copied the flags, so the fresh model could be built
        # with a different architecture than the checkpoint — the sibling
        # train_single_ext already applies the flags before construction.)
        for k in opt:
            if k in model_flags:
                setattr(args, k, opt[k])

    model = Summarizer(args, device, load_pretrained_bert=True)
    if checkpoint is not None:
        model.load_cp(checkpoint)
    # build_optim receives None when training from scratch, matching the
    # original control flow.
    optim = model_builder.build_optim(args, model, checkpoint)

    trainer = build_trainer(args, device_id, model, optim)
    trainer.train(train_iter_fct, args.train_steps)
# 处理json数据集名称,比如json_data/LCSTS.train.1.json,需要指定为LCSTS parser.add_argument('-dataset', default='LCSTS', type=str) # 模型输入训练,保存 parser.add_argument("-save_path", default='bert_data') ###change from 2000 to 16000 parser.add_argument("-shard_size", default=16000, type=int) # 最小句子量,文章不能低于3句话 parser.add_argument('-min_nsents', default=3, type=int) # 最大句子量,文章超过100句话 parser.add_argument('-max_nsents', default=100, type=int) # 句子最短长度 parser.add_argument('-min_src_ntokens', default=3, type=int) # 句子最大长度 parser.add_argument('-max_src_ntokens', default=150, type=int) parser.add_argument('-max_position_embeddings', default=512, type=int) parser.add_argument('-log_file', default='logs/preprocess.log') parser.add_argument('-n_cpus', default=4, type=int) bert_base_chinese = '/Users/jiang/Documents/bert/bert-base-chinese' parser.add_argument("-bert_base_chinese", type=str, default=bert_base_chinese) args = parser.parse_args() init_logger(args.log_file) data_builder_LAI.format_to_bert(args)