def train_base_model(
    model: NeuralModelBase,
    dataset: Dataset,
    num_epochs,
    train_iter,
    valid_iter,
    lr=0.001,
    verbose=True,
):
    """Train `model` with plain (masked) cross-entropy for `num_epochs` epochs.

    Args:
        model: model exposing `parameters()`, `fit(...)` and `accuracy(...)`.
        dataset: dataset providing the `TARGET` field used for accuracy.
        num_epochs: number of training epochs.
        train_iter: training batch iterator.
        valid_iter: a validation batch iterator, or a list of them.
        lr: Adam learning rate.
        verbose: forwarded to `model.accuracy` for the validation passes.

    Returns:
        (train_prec, valid_prec): the `mask_valid_noreject_acc` statistic on
        the training set and on the last validation iterator of the last
        epoch. Both are None if `num_epochs == 0` (train_prec is still
        computed once after the loop).
    """
    # Accept either a single validation iterator or a list of them.
    valid_iters = valid_iter if isinstance(valid_iter, list) else [valid_iter]

    Logger.start_scope("Training Model")
    opt = optim.Adam(model.parameters(), lr=lr)
    model.opt = opt
    loss_function = nn.CrossEntropyLoss(reduction="none")
    model.loss_function = loss_function

    train_prec, valid_prec = None, None
    for epoch in range(num_epochs):
        Logger.start_scope("Epoch {}".format(epoch))
        model.fit(train_iter, opt, loss_function, mask_field="mask_valid")

        # FIX: the original wrote `for valid_iter in valid_iters`, shadowing
        # and rebinding the `valid_iter` parameter. Use a distinct name.
        for v_iter in valid_iters:
            valid_stats = model.accuracy(v_iter, dataset.TARGET, verbose=verbose)
            valid_prec = valid_stats["mask_valid_noreject_acc"]
            Logger.debug(f"valid_prec: {valid_prec}")
        Logger.end_scope()

    train_stats = model.accuracy(train_iter, dataset.TARGET, verbose=False)
    train_prec = train_stats["mask_valid_noreject_acc"]
    Logger.debug(f"train_prec: {train_prec}, valid_prec: {valid_prec}")
    Logger.end_scope()
    return train_prec, valid_prec
def train_model(
    model: NeuralModelBase,
    dataset: Dataset,
    num_epochs,
    train_iter,
    valid_iter,
    lr=0.001,
    weight=None,
    target_o=1.0,
):
    """Train `model` with a rejection-aware cross-entropy loss.

    The loss's `o` coefficient is annealed per epoch in three phases:
    linearly from `o_base` down to 1.0, then from the midpoint
    `(1 + last) / 2` down to `target_o`, then held at `target_o` for the
    remaining `num_epochs // 2` epochs.

    Args:
        model: model exposing `parameters()`, `fit(...)` and `accuracy(...)`.
        dataset: dataset providing `TARGET` (with `.vocab`) and
            `reject_token_id`.
        num_epochs: controls the length of the annealing schedule; must be >= 2.
        train_iter: training batch iterator.
        valid_iter: validation batch iterator.
        lr: Adam learning rate.
        weight: optional per-class weights forwarded to the loss.
        target_o: final value of the loss's `o` coefficient.

    Returns:
        (train_prec, valid_prec): the `mask_valid_noreject_acc` statistic on
        the training and validation sets after the last epoch.

    Raises:
        ValueError: if `num_epochs < 2` (the schedule divides by
            `num_epochs // 2`).
    """
    if num_epochs < 2:
        # The original raised an opaque ZeroDivisionError here; fail clearly.
        raise ValueError(f"num_epochs must be >= 2, got {num_epochs}")

    opt = optim.Adam(model.parameters(), lr=lr)
    Logger.start_scope("Training Model")

    # NOTE(review): the original comment listed three special tokens
    # ('reject', '<unk>', '<pad>') but subtracts 4 — presumably a fourth
    # special vocab entry exists; verify against the TARGET vocab.
    o_base = len(dataset.TARGET.vocab) - 4
    loss_function = RejectionCrossEntropyLoss(
        o_base,
        len(dataset.TARGET.vocab),
        dataset.reject_token_id,
        reduction="none",
        weight=weight,
    )
    model.loss_function = loss_function
    model.opt = opt

    # Build the three-phase annealing schedule for the loss coefficient `o`.
    step = 1.0 / (num_epochs // 2)
    schedule = [
        f * o_base + (1 - f) * 1.0
        for f in np.arange(start=1.0, stop=0.0, step=-step)
    ]
    schedule += [
        f * ((1.0 + schedule[-1]) / 2) + (1 - f) * target_o
        for f in np.arange(start=1.0, stop=0.0, step=-step)
    ]
    schedule += [target_o] * (num_epochs // 2)

    train_prec, valid_prec = None, None
    for epoch, o_upper in enumerate(schedule):
        Logger.start_scope("Epoch {}, o_upper={:.3f}".format(epoch, o_upper))
        loss_function.o = o_upper
        model.fit(train_iter, opt, loss_function, mask_field="mask_valid")

        valid_stats = model.accuracy(valid_iter, dataset.TARGET)
        valid_prec = valid_stats["mask_valid_noreject_acc"]
        Logger.debug(f"valid_prec: {valid_prec}")
        Logger.end_scope()

    train_stats = model.accuracy(train_iter, dataset.TARGET, verbose=False)
    train_prec = train_stats["mask_valid_noreject_acc"]
    Logger.debug(f"train_prec: {train_prec}, valid_prec: {valid_prec}")
    Logger.end_scope()
    return train_prec, valid_prec