def load_dataset_inputs(prefix):
    # e.g. <input_dir>/train_inputs.h5, <input_dir>/dev_inputs.h5, <input_dir>/test_inputs.h5
    input_f = os.path.join(args.input_dir, '{}_inputs.h5'.format(prefix))
    return input_f


def main():
    with open(args.config_file) as f:
        config = json.load(f)
    assert config['config_target'] == 'naive_psychology'

    set_seed(args.gpu_id, args.seed)  # in distributed training, this has to be the same for all processes

    if args.mode == 'train':
        train_dinputs = load_dataset_inputs('train')
        dev_dinputs = load_dataset_inputs('dev')
        train(train_dinputs, dev_dinputs)

        # test
        test_dinputs = load_dataset_inputs('test')
        test(test_dinputs, args.output_dir)
    else:
        # test
        test_dinputs = load_dataset_inputs('test')
        test(test_dinputs, args.from_checkpoint)


if __name__ == "__main__":
    args = utils.bin_config(get_arguments)
    logger = utils.get_root_logger(args)
    main()
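# main() above calls train(train_dinputs, dev_dinputs) with two arguments, while the
# worker below is defined as train(rank, train_inputs, dev_inputs, config, args), so
# some launcher glue outside this section has to bridge the two. The function below is
# only an illustrative sketch of such a launcher, assuming torch.multiprocessing.spawn
# is used when args.n_gpus > 1 and that the 'env://' init_method in init_process_group
# reads MASTER_ADDR and MASTER_PORT from the environment; the name run_train and the
# default port are hypothetical, not part of this repository.
def run_train(train_dinputs, dev_dinputs, config, args):
    import torch.multiprocessing as mp
    if args.n_gpus > 1:
        # rendezvous settings consumed by init_method='env://' inside train()
        os.environ.setdefault('MASTER_ADDR', 'localhost')
        os.environ.setdefault('MASTER_PORT', '29500')
        # one worker per GPU; spawn passes the process index as the first argument (rank)
        mp.spawn(train,
                 args=(train_dinputs, dev_dinputs, config, args),
                 nprocs=args.n_gpus,
                 join=True)
    else:
        # single-process run: rank 0, device chosen from args.gpu_id inside train()
        train(0, train_dinputs, dev_dinputs, config, args)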
def train(rank, train_inputs, dev_inputs, config, args):
    """Run pretraining on one process; supports single-GPU and multi-GPU (DDP) training."""
    logger = utils.get_root_logger(args, log_fname='log_rank{}'.format(rank))

    # distributed setup: in multi-GPU runs each spawned process owns one GPU
    if args.n_gpus > 1:
        local_rank = rank
        args.gpu_id = rank
        dist.init_process_group(backend='nccl',
                                init_method='env://',
                                world_size=args.n_gpus,
                                rank=local_rank)
    else:
        local_rank = -1
    set_seed(args.gpu_id, args.seed)  # in distributed training, this has to be the same for all processes
    logger.info('local_rank = {}, n_gpus = {}'.format(local_rank, args.n_gpus))
    logger.info('n_epochs = {}'.format(args.n_epochs))
    if args.gpu_id != -1:
        torch.cuda.set_device(args.gpu_id)

    # relation-type distribution, re-keyed by index (consumed by batch_sample_truncated_graphs)
    rtype_distr = config['rtype_distr']
    rtype2idx = config['rtype2idx']
    rtype_distr = {rtype2idx[r]: d for r, d in rtype_distr.items()}

    # dev dataset (only the main process evaluates)
    if local_rank in [-1, 0]:
        dev_dataset = PNGPretrainDataset(dev_inputs[0], dev_inputs[1],
                                         args.task, is_train=False)
        dev_dataloader = DataLoader(
            dev_dataset,
            batch_size=args.eval_batch_size,  # fix 1 for sampling
            shuffle=False,
            collate_fn=my_dev_collate,
            num_workers=1)  # 1 is safe for hdf5

    # train dataset
    train_dataset = PNGPretrainDataset(train_inputs[0], train_inputs[1],
                                       args.task, is_train=True)
    if args.n_gpus > 1:
        train_sampler = DistributedSampler(train_dataset,
                                           num_replicas=args.n_gpus,
                                           rank=local_rank)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,  # fix 1 for sampling
        shuffle=shuffle,
        collate_fn=my_train_collate,
        num_workers=1,  # 1 is safe for hdf5
        sampler=train_sampler)

    # model
    # pos_weight = torch.FloatTensor([args.pos_weight])
    model = get_model(args.from_checkpoint, args.weight_name, args.freeze_lm)
    criterion = model.criterion
    if args.gpu_id != -1:
        model = model.cuda(args.gpu_id)
        # pos_weight = pos_weight.cuda(args.gpu_id)
    if args.n_gpus > 1:
        model = DistributedDataParallel(model,
                                        device_ids=[args.gpu_id],
                                        find_unused_parameters=True)

    # optimizer and LR scheduler
    optimizer = utils.get_optimizer_adam(model, args.weight_decay, args.lr,
                                         args.adam_epsilon, args.from_checkpoint,
                                         (args.adam_beta1, args.adam_beta2))
    # optimizer = utils.get_optimizer_adamw(
    #     model, args.weight_decay, args.lr,
    #     args.adam_epsilon, args.from_checkpoint
    # )
    n = len(train_dataset)
    scheduler = utils.get_scheduler(n, optimizer, args.train_batch_size,
                                    args.gradient_accumulation_steps, args.n_epochs,
                                    args.warmup_steps, args.warmup_portion,
                                    args.from_checkpoint)

    if local_rank in [-1, 0]:
        logger.info("***** Running training *****")
        logger.info("  Num Epochs = %d", args.n_epochs)
        logger.info("  Training batch size = %d", args.train_batch_size)
        logger.info("  Evaluation batch size = %d", args.eval_batch_size)
        logger.info("  Accu. train batch size = %d",
                    args.train_batch_size * args.gradient_accumulation_steps)
        logger.info("  Gradient Accumulation steps = %d",
                    args.gradient_accumulation_steps)
        logger.info("  Weight Decay = {}".format(args.weight_decay))
        logger.info("  Learning Rate = {}".format(args.lr))

        if args.no_first_eval or args.no_eval:
            best_metric = 0.0
        else:
            best_metric = evaluate(model, dev_dataloader, args.gpu_id, logger=logger)
        logger.info('start dev_metric = {}'.format(best_metric))
        tb_writer = SummaryWriter('{}/explog'.format(args.output_dir))
    else:
        best_metric = 0.0  # non-main processes do not evaluate

    t1 = time.time()
    step, accu_step = 0, 0
    prev_acc_loss, acc_loss = 0.0, 0.0
    model.zero_grad()
    for i_epoch in range(args.n_epochs):
        t2 = time.time()
        logger.info('========== Epoch {} =========='.format(i_epoch))
        for batch in train_dataloader:
            batch = batch_sample_truncated_graphs(batch, rtype_distr, args)
            if batch is None:  # this sometimes happens
                logger.warning('unable to sample neg edges, skip this batch')
                continue
            batch = calculate_norms(batch)

            model.train()

            # to GPU
            if args.gpu_id != -1:
                batch = to_gpu(batch, args.gpu_id)

            # forward pass
            all_scores, all_ys, all_rtypes, all_embs = model(**batch)
            loss = criterion(all_scores, all_ys, all_rtypes, all_embs)
            if args.n_gpus > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            # backward pass
            loss.backward()

            # accumulation
            acc_loss += loss.item()
            accu_step += 1
            if accu_step % args.gradient_accumulation_steps == 0:  # ignore the last accumulation
                # update params
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                step += 1

                # loss
                if local_rank in [-1, 0]:
                    if args.logging_steps > 0 and step % args.logging_steps == 0:
                        cur_loss = (acc_loss - prev_acc_loss) / args.logging_steps
                        logger.info(
                            'task={}, train_loss={}, accu_step={}, step={}, time={}s'.format(
                                args.task, cur_loss, accu_step, step, time.time() - t1))
                        tb_writer.add_scalar('train_loss', cur_loss, step)
                        tb_writer.add_scalar('lr', scheduler.get_last_lr()[0], step)

                        # evaluate
                        if not args.no_eval:
                            dev_metric = evaluate(model, dev_dataloader, args.gpu_id,
                                                  logger=logger)
                            logger.info('dev_metric={}'.format(dev_metric))
                            if best_metric < dev_metric:
                                best_metric = dev_metric
                                # save
                                utils.save_model(model, optimizer, scheduler,
                                                 args.output_dir, step)
                        prev_acc_loss = acc_loss

        logger.info('done epoch {}: {} s'.format(i_epoch, time.time() - t2))
        if local_rank in [-1, 0] and args.no_eval:
            logger.info('saving model for epoch {}'.format(i_epoch))
            utils.save_model(model, optimizer, scheduler, args.output_dir, step)

    if local_rank in [-1, 0]:
        tb_writer.close()
        logger.info('best_dev_metric = {}'.format(best_metric))
    logger.info('done training: {} s'.format(time.time() - t1))