# Training loop for the image-captioning task: masked-token cross-entropy by
# default, with optional SCST (self-critical sequence training) fine-tuning.
def train(args, train_dataset, val_dataset, model, tokenizer):
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
        batch_size=args.train_batch_size, num_workers=args.num_workers)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) //
            args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps \
            * args.num_train_epochs

    # Prepare optimizer and scheduler; biases and LayerNorm weights get no weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    grouped_parameters = [
        {'params': [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    optimizer = AdamW(grouped_parameters, lr=args.learning_rate,
        eps=args.adam_epsilon)
    if args.scheduler == "constant":
        scheduler = WarmupConstantSchedule(optimizer,
            warmup_steps=args.warmup_steps)
    elif args.scheduler == "linear":
        scheduler = WarmupLinearSchedule(optimizer,
            warmup_steps=args.warmup_steps, t_total=t_total)
    else:
        raise ValueError("Unknown scheduler type: {}".format(args.scheduler))

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps)
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    if args.scst:
        scst_criterion = ScstRewardCriterion()
        logger.info("  SCST training...")

    global_step, global_loss, global_acc = 0, 0.0, 0.0
    model.zero_grad()
    eval_log = []
    best_score = 0
    for epoch in range(int(args.num_train_epochs)):
        for step, (img_keys, batch) in enumerate(train_dataloader):
            batch = tuple(t.to(args.device) for t in batch)
            if not args.scst:
                # Cross-entropy training on the masked caption tokens.
                model.train()
                inputs = {
                    'input_ids': batch[0], 'attention_mask': batch[1],
                    'token_type_ids': batch[2], 'img_feats': batch[3],
                    'masked_pos': batch[4], 'masked_ids': batch[5],
                }
                outputs = model(**inputs)
                loss, logits = outputs[:2]
                masked_ids = inputs['masked_ids']
                masked_ids = masked_ids[masked_ids != 0]
                batch_score = compute_score_with_logits(logits, masked_ids)
                batch_acc = torch.sum(batch_score.float()) / torch.sum(inputs['masked_pos'])
            else:
                # SCST fine-tuning: the reward criterion supplies both loss and score.
                loss = scst_train_iter(args, train_dataset, model,
                    scst_criterion, img_keys, batch, tokenizer)
                batch_acc = scst_criterion.get_score()

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            global_loss += loss.item()
            global_acc += batch_acc
            if (step + 1) % args.gradient_accumulation_steps == 0:
                global_step += 1
                scheduler.step()
                optimizer.step()
                model.zero_grad()
                if global_step % args.logging_steps == 0:
                    logger.info("Epoch: {}, global_step: {}, lr: {:.6f}, loss: {:.4f} ({:.4f}), "
                        "score: {:.4f} ({:.4f})".format(epoch, global_step,
                        optimizer.param_groups[0]["lr"], loss, global_loss / global_step,
                        batch_acc, global_acc / global_step))
                if (args.save_steps > 0 and global_step % args.save_steps == 0) or \
                        global_step == t_total:
                    checkpoint_dir = save_checkpoint(model, tokenizer, args, epoch, global_step)
                    # evaluation
                    if args.evaluate_during_training:
                        logger.info("Perform evaluation at step: %d" % global_step)
                        evaluate_file = evaluate(args, val_dataset, model, tokenizer,
                            checkpoint_dir)
                        with open(evaluate_file, 'r') as f:
                            res = json.load(f)
                        best_score = max(best_score, res['CIDEr'])
                        res['epoch'] = epoch
                        res['global_step'] = global_step
                        res['best_CIDEr'] = best_score
                        eval_log.append(res)
                        with open(args.output_dir + '/eval_logs.json', 'w') as f:
                            json.dump(eval_log, f)
    return global_step, global_loss / global_step
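# Both training loops call a compute_score_with_logits helper that is not shown in
# this section. The sketch below illustrates what it is assumed to do, based on its
# call sites above (per-position argmax over the logits compared against the gold
# ids); the name and signature follow the calls, but the body is an illustrative
# assumption, not the repository's definitive implementation.
import torch  # assumed to be imported at module level; repeated so the sketch stands alone

def compute_score_with_logits(logits, labels):
    # Pick the highest-scoring class/token id at each position ...
    preds = torch.max(logits, -1)[1].data
    # ... and return a 0/1 tensor marking where the prediction matches the label,
    # which the callers then sum or average into a batch accuracy.
    return preds == labels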
# Training loop for the image-text retrieval task: binary matching labels over
# concatenated caption-image pairs, with R@K evaluation during training.
def train(args, train_dataset, val_dataset, model, tokenizer):
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler,
        batch_size=args.train_batch_size, num_workers=args.num_workers)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) //
            args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps \
            * args.num_train_epochs

    # Prepare optimizer and scheduler; biases and LayerNorm weights get no weight decay.
    no_decay = ['bias', 'LayerNorm.weight']
    grouped_parameters = [
        {'params': [p for n, p in model.named_parameters()
                    if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters()
                    if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0},
    ]
    optimizer = AdamW(grouped_parameters, lr=args.learning_rate,
        eps=args.adam_epsilon)
    if args.scheduler == "constant":
        scheduler = WarmupConstantSchedule(optimizer,
            warmup_steps=args.warmup_steps)
    elif args.scheduler == "linear":
        scheduler = WarmupLinearSchedule(optimizer,
            warmup_steps=args.warmup_steps, t_total=t_total)
    else:
        raise ValueError("Unknown scheduler type: {}".format(args.scheduler))

    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps)
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step, global_loss, global_acc = 0, 0.0, 0.0
    model.zero_grad()
    log_json = []
    best_score = 0
    for epoch in range(int(args.num_train_epochs)):
        for step, (_, batch) in enumerate(train_dataloader):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            # Each batch carries two sets of caption-image pairs (batch[0:5] and
            # batch[5:10]); they are concatenated into one forward pass, doubling
            # the effective batch size.
            inputs = {
                'input_ids': torch.cat((batch[0], batch[5]), dim=0),
                'attention_mask': torch.cat((batch[1], batch[6]), dim=0),
                'token_type_ids': torch.cat((batch[2], batch[7]), dim=0),
                'img_feats': torch.cat((batch[3], batch[8]), dim=0),
                'labels': torch.cat((batch[4], batch[9]), dim=0),
            }
            outputs = model(**inputs)
            loss, logits = outputs[:2]
            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            batch_score = compute_score_with_logits(logits, inputs['labels']).sum()
            batch_acc = batch_score.item() / (args.train_batch_size * 2)
            global_loss += loss.item()
            global_acc += batch_acc
            if (step + 1) % args.gradient_accumulation_steps == 0:
                global_step += 1
                scheduler.step()
                optimizer.step()
                model.zero_grad()
                if global_step % args.logging_steps == 0:
                    logger.info("Epoch: {}, global_step: {}, lr: {:.6f}, loss: {:.4f} ({:.4f}), "
                        "score: {:.4f} ({:.4f})".format(epoch, global_step,
                        optimizer.param_groups[0]["lr"], loss, global_loss / global_step,
                        batch_acc, global_acc / global_step))
                if (args.save_steps > 0 and global_step % args.save_steps == 0) or \
                        global_step == t_total:
                    save_checkpoint(model, tokenizer, args, epoch, global_step)
                    # evaluation
                    if args.evaluate_during_training:
                        logger.info("Perform evaluation at step: %d" % global_step)
                        test_result = test(args, model, val_dataset)
                        eval_result = evaluate(val_dataset, test_result)
                        rank_accs = eval_result['i2t_retrieval']
                        if rank_accs['R@1'] > best_score:
                            best_score = rank_accs['R@1']
                        epoch_log = {
                            'epoch': epoch, 'global_step': global_step,
                            'R1': rank_accs['R@1'], 'R5': rank_accs['R@5'],
                            'R10': rank_accs['R@10'], 'best_R1': best_score,
                        }
                        log_json.append(epoch_log)
                        with open(args.output_dir + '/eval_logs.json', 'w') as fp:
                            json.dump(log_json, fp)
    return global_step, global_loss / global_step
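# A minimal sketch of how either train() above might be driven. The attribute names
# on `args` mirror exactly what the loops read; the concrete values, and the way the
# model, tokenizer, and datasets are constructed, are illustrative assumptions rather
# than the repository's actual command-line entry point.
import torch
from argparse import Namespace

args = Namespace(
    per_gpu_train_batch_size=32,          # values below are illustrative only
    n_gpu=torch.cuda.device_count(),
    num_workers=4,
    max_steps=-1,                          # <= 0 means "derive t_total from epochs"
    num_train_epochs=30,
    gradient_accumulation_steps=1,
    weight_decay=0.05,
    learning_rate=2e-5,
    adam_epsilon=1e-8,
    scheduler="linear",
    warmup_steps=0,
    max_grad_norm=1.0,
    logging_steps=20,
    save_steps=5000,
    evaluate_during_training=True,
    output_dir="output",
    scst=False,                            # only read by the captioning loop
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

# model, tokenizer, train_dataset, and val_dataset are assumed to be built elsewhere
# in the script before the loop is invoked, e.g.:
# global_step, avg_loss = train(args, train_dataset, val_dataset, model, tokenizer)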