def define_lr_scheduler(self):
    if self.learning_rate_scheduler == self.ft_learning_rate_scheduler:
        if self.learning_rate_scheduler == LR_LINEAR:
            self.lr_scheduler = lrs.LinearScheduler(
                self.optim,
                [self.warmup_init_lr, self.warmup_init_ft_lr],
                [self.num_warmup_steps, self.num_warmup_steps],
                self.num_steps)
        elif self.learning_rate_scheduler == LR_INVERSE_SQUARE:
            self.lr_scheduler = lrs.InverseSquareRootScheduler(
                self.optim,
                [self.warmup_init_lr, self.warmup_init_ft_lr],
                [self.num_warmup_steps, self.num_warmup_steps],
                self.num_steps)
        elif self.learning_rate_scheduler == LR_INVERSE_POWER:
            self.lr_scheduler = lrs.InversePowerScheduler(
                self.optim,
                1.0,
                [self.warmup_init_lr, self.warmup_init_ft_lr],
                [self.num_warmup_steps, self.num_warmup_steps])
        elif self.learning_rate_scheduler == LR_PLATEAU:
            self.lr_scheduler = lrs.ReduceLROnPlateau(
                self.optim, factor=0.5, patience=5, min_lr=1e-5, verbose=True)
        else:
            raise NotImplementedError
    else:
        assert (self.learning_rate_scheduler == LR_LINEAR and
                self.ft_learning_rate_scheduler == LR_INVERSE_SQUARE)
        self.lr_scheduler = lrs.HybridScheduler(
            self.optim,
            [self.learning_rate_scheduler, self.ft_learning_rate_scheduler],
            [self.warmup_init_lr, self.warmup_init_ft_lr],
            [self.num_warmup_steps, self.num_warmup_steps],
            self.num_steps)
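
# The lrs.LinearScheduler implementation is not shown here; the class below is a
# minimal sketch, assuming it performs per-parameter-group linear warmup from a
# small initial learning rate to the group's base rate, followed by linear decay
# to zero over num_steps. The name LinearSchedulerSketch and the exact warmup/decay
# formulas are illustrative assumptions, not the codebase's actual implementation.
class LinearSchedulerSketch(object):
    def __init__(self, optimizer, warmup_init_lrs, num_warmup_steps, num_steps):
        self.optimizer = optimizer
        self.warmup_init_lrs = warmup_init_lrs
        self.num_warmup_steps = num_warmup_steps
        self.num_steps = num_steps
        # Remember each group's target (base) learning rate
        self.base_lrs = [group['lr'] for group in optimizer.param_groups]
        self._step = 0

    def step(self):
        self._step += 1
        for i, group in enumerate(self.optimizer.param_groups):
            warmup = self.num_warmup_steps[i]
            init_lr, base_lr = self.warmup_init_lrs[i], self.base_lrs[i]
            if self._step <= warmup:
                # Linear warmup: init_lr -> base_lr over the warmup steps
                lr = init_lr + (base_lr - init_lr) * self._step / max(1, warmup)
            else:
                # Linear decay: base_lr -> 0 over the remaining steps
                remaining = max(0, self.num_steps - self._step)
                lr = base_lr * remaining / max(1, self.num_steps - warmup)
            group['lr'] = lr


if __name__ == '__main__':
    # Illustrative usage with two parameter groups and made-up hyperparameters.
    import torch
    optimizer = torch.optim.Adam(
        [{'params': [torch.nn.Parameter(torch.zeros(1))]},
         {'params': [torch.nn.Parameter(torch.zeros(1))], 'lr': 3e-5}],
        lr=1e-3)
    scheduler = LinearSchedulerSketch(optimizer, [1e-7, 1e-7], [1000, 1000], 20000)
    scheduler.step()  # called once per training step, as in the code above
    print([g['lr'] for g in optimizer.param_groups])
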
def train(train_data, dev_data):
    # Model
    model_dir = get_model_dir(args)
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)
    trans_checker = TranslatabilityChecker(args)
    trans_checker.cuda()
    ops.initialize_module(trans_checker, 'xavier')
    wandb.init(project='translatability-prediction', name=get_wandb_tag(args))
    wandb.watch(trans_checker)

    # Hyperparameters
    batch_size = 16
    num_peek_epochs = 1

    # Loss function
    # -100 is a dummy padding value since all output spans will be of length 2
    loss_fun = MaskedCrossEntropyLoss(-100)

    # Optimizer: the pre-trained encoder ('trans_parameters') is fine-tuned at a
    # separate, typically smaller, learning rate
    optimizer = optim.Adam(
        [{'params': [p for n, p in trans_checker.named_parameters()
                     if 'trans_parameters' not in n and p.requires_grad]},
         {'params': [p for n, p in trans_checker.named_parameters()
                     if 'trans_parameters' in n and p.requires_grad],
          'lr': args.bert_finetune_rate}],
        lr=args.learning_rate)
    lr_scheduler = lrs.LinearScheduler(
        optimizer,
        [args.warmup_init_lr, args.warmup_init_ft_lr],
        [args.num_warmup_steps, args.num_warmup_steps],
        args.num_steps)

    best_dev_metrics = 0
    for epoch_id in range(args.num_epochs):
        random.shuffle(train_data)
        trans_checker.train()
        optimizer.zero_grad()

        epoch_losses = []
        for i in tqdm(range(0, len(train_data), batch_size)):
            wandb.log({'learning_rate/{}'.format(args.dataset_name):
                       optimizer.param_groups[0]['lr']})
            wandb.log({'fine_tuning_rate/{}'.format(args.dataset_name):
                       optimizer.param_groups[1]['lr']})
            mini_batch = train_data[i:i + batch_size]
            _, text_masks = ops.pad_batch(
                [exp.text_ids for exp in mini_batch], bu.pad_id)
            encoder_input_ids = ops.pad_batch(
                [exp.ptr_input_ids for exp in mini_batch], bu.pad_id)
            target_span_ids, _ = ops.pad_batch(
                [exp.span_ids for exp in mini_batch], bu.pad_id)
            output = trans_checker(encoder_input_ids, text_masks)
            loss = loss_fun(output, target_span_ids)
            loss.backward()
            epoch_losses.append(float(loss))
            if args.grad_norm > 0:
                nn.utils.clip_grad_norm_(trans_checker.parameters(), args.grad_norm)
            lr_scheduler.step()
            optimizer.step()
            optimizer.zero_grad()

        # Evaluate on the dev set every num_peek_epochs epochs
        if epoch_id % num_peek_epochs == 0:
            stdout_msg = 'Epoch {}: average training loss = {}'.format(
                epoch_id, np.mean(epoch_losses))
            print(stdout_msg)
            wandb.log({'cross_entropy_loss/{}'.format(args.dataset_name):
                       np.mean(epoch_losses)})
            pred_spans = trans_checker.inference(dev_data)
            target_spans = [exp.span_ids for exp in dev_data]
            trans_acc = translatablity_eval(pred_spans, target_spans)
            print('Dev translatability accuracy = {}'.format(trans_acc))
            if trans_acc > best_dev_metrics:
                model_path = os.path.join(model_dir, 'model-best.tar')
                trans_checker.save_checkpoint(optimizer, lr_scheduler, model_path)
                best_dev_metrics = trans_acc
            span_acc, prec, recall, f1 = span_eval(pred_spans, target_spans)
            print('Dev span accuracy = {}'.format(span_acc))
            print('Dev span precision = {}'.format(prec))
            print('Dev span recall = {}'.format(recall))
            print('Dev span F1 = {}'.format(f1))
            wandb.log({'translatability_accuracy/{}'.format(args.dataset_name):
                       trans_acc})
            wandb.log({'span_accuracy/{}'.format(args.dataset_name): span_acc})
            wandb.log({'span_f1/{}'.format(args.dataset_name): f1})
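
# MaskedCrossEntropyLoss is defined elsewhere in the codebase; the sketch below is an
# assumption of its behavior inferred from how it is called above: a cross-entropy
# loss over the span-pointer logits that ignores target positions equal to the given
# pad value (-100), so padded span targets contribute nothing to the gradient.
# The class name and the assumed shapes ([batch, positions, seq_len] logits,
# [batch, positions] targets) are illustrative, not the actual implementation.
import torch
import torch.nn as nn


class MaskedCrossEntropyLossSketch(nn.Module):
    def __init__(self, pad_id):
        super().__init__()
        # nn.CrossEntropyLoss natively skips targets equal to ignore_index
        self.ce = nn.CrossEntropyLoss(ignore_index=pad_id)

    def forward(self, logits, targets):
        # logits: [batch, positions, seq_len]; targets: [batch, positions]
        return self.ce(logits.flatten(0, 1), targets.flatten())


# Illustrative call: the second example's targets are fully padded and thus ignored.
loss = MaskedCrossEntropyLossSketch(-100)(
    torch.randn(2, 2, 10), torch.tensor([[1, 4], [-100, -100]]))
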
def train(train_data, dev_data):
    # Model
    model_dir = get_model_dir(args)
    if not os.path.exists(model_dir):
        os.mkdir(model_dir)
    trans_checker = TranslatabilityChecker(args)
    trans_checker.cuda()
    ops.initialize_module(trans_checker, 'xavier')

    # Hyperparameters
    batch_size = min(len(train_data), 12)
    num_peek_epochs = 1

    # Loss functions: binary translatability classification + untranslatable span extraction
    loss_fun = nn.BCELoss()
    span_extract_pad_id = -100
    span_extract_loss_fun = MaskedCrossEntropyLoss(span_extract_pad_id)

    # Optimizer: the pre-trained encoder ('trans_parameters') is fine-tuned at a
    # separate, typically smaller, learning rate
    optimizer = optim.Adam(
        [{'params': [p for n, p in trans_checker.named_parameters()
                     if 'trans_parameters' not in n and p.requires_grad]},
         {'params': [p for n, p in trans_checker.named_parameters()
                     if 'trans_parameters' in n and p.requires_grad],
          'lr': args.bert_finetune_rate}],
        lr=args.learning_rate)
    lr_scheduler = lrs.LinearScheduler(
        optimizer,
        [args.warmup_init_lr, args.warmup_init_ft_lr],
        [args.num_warmup_steps, args.num_warmup_steps],
        args.num_steps)

    best_dev_metrics = 0
    for epoch_id in range(args.num_epochs):
        random.shuffle(train_data)
        trans_checker.train()
        optimizer.zero_grad()

        epoch_losses = []
        for i in tqdm(range(0, len(train_data), batch_size)):
            mini_batch = train_data[i:i + batch_size]
            _, text_masks = ops.pad_batch(
                [exp.text_ids for exp in mini_batch], bu.pad_id)
            encoder_input_ids = ops.pad_batch(
                [exp.ptr_input_ids for exp in mini_batch], bu.pad_id)
            # Examples whose span starts at position 0 are labeled translatable (1)
            target_ids = ops.int_var_cuda(
                [1 if exp.span_ids[0] == 0 else 0 for exp in mini_batch])
            target_span_ids, _ = ops.pad_batch(
                [exp.span_ids for exp in mini_batch], bu.pad_id)
            # Replace the span targets of translatable examples with the pad id so
            # they are ignored by the span-extraction loss
            target_span_ids = target_span_ids * (1 - target_ids.unsqueeze(1)) + \
                target_ids.unsqueeze(1).expand_as(target_span_ids) * span_extract_pad_id
            output, span_extract_output = trans_checker(encoder_input_ids, text_masks)
            loss = loss_fun(output, target_ids.unsqueeze(1).float())
            span_extract_loss = span_extract_loss_fun(span_extract_output, target_span_ids)
            loss += span_extract_loss
            loss.backward()
            epoch_losses.append(float(loss))
            if args.grad_norm > 0:
                nn.utils.clip_grad_norm_(trans_checker.parameters(), args.grad_norm)
            lr_scheduler.step()
            optimizer.step()
            optimizer.zero_grad()

        with torch.no_grad():
            # Evaluate on the dev set every num_peek_epochs epochs
            if epoch_id % num_peek_epochs == 0:
                stdout_msg = 'Epoch {}: average training loss = {}'.format(
                    epoch_id, np.mean(epoch_losses))
                print(stdout_msg)
                pred_trans, pred_spans = trans_checker.inference(dev_data)
                targets = [1 if exp.span_ids[0] == 0 else 0 for exp in dev_data]
                target_spans = [exp.span_ids for exp in dev_data]
                trans_acc = translatablity_eval(pred_trans, targets)
                print('Dev translatability accuracy = {}'.format(trans_acc))
                if trans_acc > best_dev_metrics:
                    model_path = os.path.join(model_dir, 'model-best.tar')
                    trans_checker.save_checkpoint(optimizer, lr_scheduler, model_path)
                    best_dev_metrics = trans_acc
                span_acc, prec, recall, f1 = span_eval(pred_spans, target_spans)
                print('Dev span accuracy = {}'.format(span_acc))
                print('Dev span precision = {}'.format(prec))
                print('Dev span recall = {}'.format(recall))
                print('Dev span F1 = {}'.format(f1))
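
# A small, self-contained demonstration (with illustrative values) of the target
# masking used in the second train() above: examples flagged as translatable
# (target_ids == 1) have their span targets overwritten with the pad id so the
# span-extraction loss skips them, while untranslatable examples keep their spans.
import torch

span_extract_pad_id = -100
target_ids = torch.tensor([1, 0])                  # ex. 0 translatable, ex. 1 not
target_span_ids = torch.tensor([[0, 0], [3, 5]])   # (start, end) span per example

masked = target_span_ids * (1 - target_ids.unsqueeze(1)) + \
    target_ids.unsqueeze(1).expand_as(target_span_ids) * span_extract_pad_id
print(masked)
# tensor([[-100, -100],
#         [   3,    5]])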