def train(args, model, processor):
    """Train a tagging model with Adam + plateau LR scheduling.

    Runs `args.epochs` epochs; after each epoch evaluates the model and,
    when `eval_f1` improves, saves a checkpoint to
    `args.output_dir / 'best-model.bin'`.

    Args:
        args: namespace with batch_size, seed, label2id, learning_rate,
            epochs, device, grad_norm, arch, output_dir (a Path).
        model: network exposing `forward_loss(ids, mask, lens, tags)`.
        processor: data processor providing `vocab`.
    """
    train_dataset = load_and_cache_examples(args, processor, data_type='train')
    # shuffle=False + sort=True: loader orders batches by length, not randomly.
    train_loader = DatasetLoader(data=train_dataset, batch_size=args.batch_size,
                                 shuffle=False, seed=args.seed, sort=True,
                                 vocab=processor.vocab, label2id=args.label2id)
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(parameters, lr=args.learning_rate)
    # NOTE(review): takes `epsilon` and exposes `epoch_step`, so this is a
    # project-local ReduceLROnPlateau, not torch's — confirm its module.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3,
                                  verbose=1, epsilon=1e-4, cooldown=0,
                                  min_lr=0, eps=1e-8)
    best_f1 = 0
    for epoch in range(1, 1 + args.epochs):
        print(f"Epoch {epoch}/{args.epochs}")
        pbar = ProgressBar(n_total=len(train_loader), desc='Training')
        train_loss = AverageMeter()
        model.train()
        assert model.training
        for step, batch in enumerate(train_loader):
            input_ids, input_mask, input_tags, input_lens = batch
            input_ids = input_ids.to(args.device)
            input_mask = input_mask.to(args.device)
            input_tags = input_tags.to(args.device)
            features, loss = model.forward_loss(input_ids, input_mask,
                                                input_lens, input_tags)
            loss.backward()
            # Clip gradients before the optimizer step.
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            pbar(step=step, info={'loss': loss.item()})
            train_loss.update(loss.item(), n=1)
        print(" ")
        train_log = {'loss': train_loss.avg}
        if 'cuda' in str(args.device):
            torch.cuda.empty_cache()
        eval_log, class_info = evaluate(args, model, processor)
        logs = dict(train_log, **eval_log)
        show_info = f'\nEpoch: {epoch} - ' + "-".join(
            [f' {key}: {value:.4f} ' for key, value in logs.items()])
        logger.info(show_info)
        # Step the (custom) scheduler on the monitored metric.
        scheduler.epoch_step(logs['eval_f1'], epoch)
        if logs['eval_f1'] > best_f1:
            logger.info(f"\nEpoch {epoch}: eval_f1 improved from {best_f1} to {logs['eval_f1']}")
            logger.info("save model to disk.")
            best_f1 = logs['eval_f1']
            # Unwrap DataParallel so the saved state_dict has clean key names.
            if isinstance(model, nn.DataParallel):
                model_stat_dict = model.module.state_dict()
            else:
                model_stat_dict = model.state_dict()
            state = {'epoch': epoch, 'arch': args.arch,
                     'state_dict': model_stat_dict}
            model_path = args.output_dir / 'best-model.bin'
            torch.save(state, str(model_path))
            # NOTE(review): indentation reconstructed — the per-entity report
            # below is assumed to belong to the best-model branch; confirm.
            print("Eval Entity Score: ")
            for key, value in class_info.items():
                info = f"Subject: {key} - Acc: {value['acc']} - Recall: {value['recall']} - F1: {value['f1']}"
                logger.info(info)
def initialize(mode, is_gpu, dir_data, di_set_transform, ext_img,
               n_img_per_batch, n_worker):
    """Build data loaders plus net/criterion/optimizer/scheduler for training.

    `mode` selects which data-loader factory is used; any unrecognised
    mode falls back to the TensorDataset-based loader.

    Returns:
        (trainloader, testloader, net, criterion, optimizer, scheduler,
        li_class)
    """
    if mode == 'TORCHVISION_MEMORY':
        # This factory alone takes no image-extension argument.
        trainloader, testloader, li_class = make_dataloader_torchvison_memory(
            dir_data, di_set_transform, n_img_per_batch, n_worker)
    else:
        factory = {
            'TORCHVISION_IMAGEFOLDER': make_dataloader_torchvison_imagefolder,
            'CUSTOM_MEMORY': make_dataloader_custom_memory,
            'CUSTOM_FILE': make_dataloader_custom_file,
        }.get(mode, make_dataloader_custom_tensordataset)
        trainloader, testloader, li_class = factory(
            dir_data, di_set_transform, ext_img, n_img_per_batch, n_worker)

    net = Net()
    criterion = nn.CrossEntropyLoss()
    if is_gpu:
        # Move both the model and the loss module onto the GPU.
        net.cuda()
        criterion.cuda()

    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    # NOTE(review): `epsilon` is not a torch kwarg — presumably a
    # project-local ReduceLROnPlateau; confirm which one is imported.
    scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=1, patience=8,
                                  epsilon=0.00001, min_lr=0.000001)
    return trainloader, testloader, net, criterion, optimizer, scheduler, li_class
def initialize(dev, dir_data, size_img, ext_img, n_img_per_batch, n_worker,
               li_idx_sample_ratio=None):
    """Assemble everything needed for one regression-training run.

    Args:
        dev: torch device the network and loss are moved to.
        dir_data: dataset root directory.
        size_img: target image size for the loader.
        ext_img: image file extension to look for.
        n_img_per_batch: batch size.
        n_worker: number of loader workers.
        li_idx_sample_ratio: optional sampling-ratio list forwarded to
            the loader factory.

    Returns:
        (trainloader, testloader, net, criterion, optimizer, scheduler,
        li_idx_sample, li_fn_sample)
    """
    trainloader, testloader, li_idx_sample, li_fn_sample = \
        make_dataloader_custom_file(dir_data, size_img, ext_img,
                                    n_img_per_batch, n_worker,
                                    li_idx_sample_ratio)

    # Regression network on 64x64 inputs, trained against an MSE objective.
    net = Network((64, 64)).to(dev)
    criterion = nn.MSELoss()
    criterion.to(dev)

    optimizer = optim.Adam(net.parameters())
    # NOTE(review): `epsilon` is not a torch kwarg — presumably a
    # project-local ReduceLROnPlateau; confirm which one is imported.
    scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=1, patience=8,
                                  epsilon=0.00001, min_lr=0.000001)

    return (trainloader, testloader, net, criterion, optimizer, scheduler,
            li_idx_sample, li_fn_sample)
def initialize(is_gpu, dir_data, di_set_transform, ext_img, n_img_per_batch,
               n_worker):
    """Build loaders, a GAP classifier, loss, optimizer and LR scheduler.

    Returns:
        (trainloader, testloader, net, criterion, optimizer, scheduler,
        li_class)
    """
    loader_train, loader_test, li_class = make_dataloader_custom_file(
        dir_data, di_set_transform, ext_img, n_img_per_batch, n_worker)

    # Global-average-pooling classifier trained with cross-entropy.
    model = Net_gap()
    loss_fn = nn.CrossEntropyLoss()
    if is_gpu:
        # Move both the model and the loss module onto the GPU.
        model.cuda()
        loss_fn.cuda()

    sgd = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    # NOTE(review): `epsilon` is not a torch kwarg — presumably a
    # project-local ReduceLROnPlateau; confirm which one is imported.
    plateau = ReduceLROnPlateau(sgd, 'min', verbose=1, patience=8,
                                epsilon=0.00001, min_lr=0.000001)
    return loader_train, loader_test, model, loss_fn, sgd, plateau, li_class
def train():
    """Train the X-UMX source-separation model with NNabla.

    Builds training and validation graphs, then runs an epoch loop with
    plateau LR scheduling and early stopping; the model with the lowest
    validation loss is saved to `<output>/best_xumx.h5` by rank 0.
    """
    # Check NNabla version — memory-efficiency fixes landed in v1.19.0.
    if utils.get_nnabla_version_integer() < 11900:
        raise ValueError(
            'Please update the nnabla version to v1.19.0 or latest version since memory efficiency of core engine is improved in v1.19.0'
        )
    parser, args = get_train_args()

    # Get context (CPU/GPU backend) and wrap it for multi-process training.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    ext = import_extension_module(args.context)

    # Monitors: setting up monitors for logging.
    monitor_path = args.output
    monitor = Monitor(monitor_path)
    monitor_best_epoch = MonitorSeries('Best epoch', monitor, interval=1)
    monitor_traing_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_validation_loss = MonitorSeries('Validation loss', monitor,
                                            interval=1)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=1)
    monitor_time = MonitorTimeElapsed("training time per iteration", monitor,
                                      interval=1)

    if comm.rank == 0:
        print("Mixing coef. is {}, i.e., MDL = {}*TD-Loss + FD-Loss".format(
            args.mcoef, args.mcoef))
        if not os.path.isdir(args.output):
            os.makedirs(args.output)

    # Initialize DataIterator for MUSDB.
    train_source, valid_source, args = load_datasources(parser, args)

    train_iter = data_iterator(train_source,
                               args.batch_size,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    valid_iter = data_iterator(valid_source,
                               1,
                               RandomState(args.seed),
                               with_memory_cache=False,
                               with_file_cache=False)

    if comm.n_procs > 1:
        # Shard both iterators across processes for multi-GPU training.
        train_iter = train_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)
        valid_iter = valid_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

    # Calculate maxiter per GPU device; scale weight decay by process count.
    max_iter = int((train_source._size // args.batch_size) // comm.n_procs)
    weight_decay = args.weight_decay * comm.n_procs

    print("max_iter", max_iter)

    # Calculate the statistics (mean and variance) of the dataset.
    scaler_mean, scaler_std = utils.get_statistics(args, train_source)

    max_bin = utils.bandwidth_to_max_bin(train_source.sample_rate, args.nfft,
                                         args.bandwidth)

    unmix = OpenUnmix_CrossNet(input_mean=scaler_mean,
                               input_scale=scaler_std,
                               nb_channels=args.nb_channels,
                               hidden_size=args.hidden_size,
                               n_fft=args.nfft,
                               n_hop=args.nhop,
                               max_bin=max_bin)

    # Create input variables (shapes taken from the first data sample).
    mixture_audio = nn.Variable([args.batch_size] +
                                list(train_source._get_data(0)[0].shape))
    target_audio = nn.Variable([args.batch_size] +
                               list(train_source._get_data(0)[1].shape))
    vmixture_audio = nn.Variable(
        [1] + [2, valid_source.sample_rate * args.valid_dur])
    vtarget_audio = nn.Variable(
        [1] + [8, valid_source.sample_rate * args.valid_dur])

    # Create training graph: frequency-domain MSE + time-domain SDR loss.
    mix_spec, M_hat, pred = unmix(mixture_audio)
    Y = Spectrogram(*STFT(target_audio, n_fft=unmix.n_fft, n_hop=unmix.n_hop),
                    mono=(unmix.nb_channels == 1))
    loss_f = mse_loss(mix_spec, M_hat, Y)
    loss_t = sdr_loss(mixture_audio, pred, target_audio)
    loss = args.mcoef * loss_t + loss_f
    loss.persistent = True  # keep the loss buffer alive across forward calls

    # Create Solver and set parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # Create validation graph (test mode, batch size 1).
    vmix_spec, vM_hat, vpred = unmix(vmixture_audio, test=True)
    vY = Spectrogram(*STFT(vtarget_audio, n_fft=unmix.n_fft,
                           n_hop=unmix.n_hop),
                     mono=(unmix.nb_channels == 1))
    vloss_f = mse_loss(vmix_spec, vM_hat, vY)
    vloss_t = sdr_loss(vmixture_audio, vpred, vtarget_audio)
    vloss = args.mcoef * vloss_t + vloss_f
    vloss.persistent = True

    # Initialize Early Stopping.
    es = utils.EarlyStopping(patience=args.patience)

    # Initialize LR Scheduler (ReduceLROnPlateau — project-local variant
    # driven via `update_lr`, not torch's).
    lr_scheduler = ReduceLROnPlateau(lr=args.lr,
                                     factor=args.lr_decay_gamma,
                                     patience=args.lr_decay_patience)
    best_epoch = 0

    # Training loop.
    for epoch in trange(args.epochs):
        # TRAINING
        losses = utils.AverageMeter()
        for batch in range(max_iter):
            mixture_audio.d, target_audio.d = train_iter.next()
            solver.zero_grad()
            loss.forward(clear_no_need_grad=True)
            if comm.n_procs > 1:
                # Gradient all-reduce overlapped with backward.
                all_reduce_callback = comm.get_all_reduce_callback()
                loss.backward(clear_buffer=True,
                              communicator_callbacks=all_reduce_callback)
            else:
                loss.backward(clear_buffer=True)
            solver.weight_decay(weight_decay)
            solver.update()
            losses.update(loss.d.copy(), args.batch_size)
        training_loss = losses.avg

        # Clear cache memory between phases.
        ext.clear_memory_cache()

        # VALIDATION: slide a fixed-duration window over each full clip
        # and average the per-window losses.
        vlosses = utils.AverageMeter()
        for batch in range(int(valid_source._size // comm.n_procs)):
            x, y = valid_iter.next()
            dur = int(valid_source.sample_rate * args.valid_dur)
            sp, cnt = 0, 0
            loss_tmp = nn.NdArray()
            loss_tmp.zero()
            while 1:
                vmixture_audio.d = x[Ellipsis, sp:sp + dur]
                vtarget_audio.d = y[Ellipsis, sp:sp + dur]
                vloss.forward(clear_no_need_grad=True)
                cnt += 1
                sp += dur
                loss_tmp += vloss.data
                # Stop when the next window would be short or the clip is done.
                if x[Ellipsis, sp:sp + dur].shape[-1] < dur or x.shape[-1] == cnt * dur:
                    break
            loss_tmp = loss_tmp / cnt
            if comm.n_procs > 1:
                comm.all_reduce(loss_tmp, division=True, inplace=True)
            vlosses.update(loss_tmp.data.copy(), 1)
        validation_loss = vlosses.avg

        # Clear cache memory.
        ext.clear_memory_cache()

        lr = lr_scheduler.update_lr(validation_loss, epoch=epoch)
        solver.set_learning_rate(lr)
        stop = es.step(validation_loss)

        if comm.rank == 0:
            monitor_best_epoch.add(epoch, best_epoch)
            monitor_traing_loss.add(epoch, training_loss)
            monitor_validation_loss.add(epoch, validation_loss)
            monitor_lr.add(epoch, lr)
            monitor_time.add(epoch)

            if validation_loss == es.best:
                # Save best model.
                nn.save_parameters(os.path.join(args.output, 'best_xumx.h5'))
                best_epoch = epoch

        if stop:
            print("Apply Early Stopping")
            break
def train_model(args):
    """Train the BERT contextual-bandit answer-selection model.

    Combines an RL (bandit) loss with binary cross-entropy, weighted by
    `args.gamma`. Periodically evaluates on `args.eval_data` and saves
    the model whenever the evaluation reward improves.

    Returns:
        The trained model (`bert_cb`).
    """
    print(args)
    print("generating config")
    config = Config(
        input_dim=args.input_dim,
        dropout=args.dropout,
        highway=args.highway,
        nn_layers=args.nn_layers,
    )

    # Encode the full hyper-parameter setting into the model file name.
    model_name = ".".join(
        (args.model_file, str(args.rl_baseline_method), args.sampling_method,
         "gamma", str(args.gamma), "beta", str(args.beta), "batch",
         str(args.train_batch), "learning_rate",
         str(args.lr) + "-" + str(args.lr_sch), "bsz", str(args.batch_size),
         "data", args.data_dir.split('/')[0], args.eval_data, "input_dim",
         str(config.input_dim), "max_num", str(args.max_num_of_ans), "reward",
         str(args.reward_type), "dropout",
         str(args.dropout) + "-" + str(args.clip_grad), "highway",
         str(args.highway), "nn-" + str(args.nn_layers), 'ans'))
    # Same encoding for the log file name.
    log_name = ".".join(
        ("log_bert/model", str(args.rl_baseline_method), args.sampling_method,
         "gamma", str(args.gamma), "beta", str(args.beta), "batch",
         str(args.train_batch), "lr", str(args.lr) + "-" + str(args.lr_sch),
         "bsz", str(args.batch_size), "data", args.data_dir.split('/')[0],
         args.eval_data, "input_dim", str(config.input_dim), "max_num",
         str(args.max_num_of_ans), "reward", str(args.reward_type), "dropout",
         str(args.dropout) + "-" + str(args.clip_grad), "highway",
         str(args.highway), "nn-" + str(args.nn_layers), 'ans'))

    print("initialising data loader and RL learner")
    data_loader = PickleReader(args.data_dir)
    data = args.data_dir.split('/')[0]
    num_data = 0
    if data == "wiki_qa":
        num_data = 873
    elif data == "trec_qa":
        num_data = 1229
    else:
        # NOTE(review): unsupported dataset; a raised ValueError would be
        # clearer than this always-false assert.
        assert (1 == 2)

    # init statistics
    reward_list = []
    loss_list = []
    best_eval_reward = 0.
    model_save_name = model_name

    bandit = ContextualBandit(b=args.batch_size,
                              rl_baseline_method=args.rl_baseline_method,
                              sample_method=args.sampling_method)

    print("Loaded the Bandit")

    bert_cb = model2.BERT_CB(config)

    print("Loaded the model")

    bert_cb.cuda()

    vocab = "vocab"
    if args.load_ext:
        # Resume from an existing serialized model and re-evaluate it to
        # seed `best_eval_reward`.
        model_name = args.model_file
        print("loading existing model%s" % model_name)
        bert_cb = torch.load(model_name,
                             map_location=lambda storage, loc: storage)
        bert_cb.cuda()
        model_save_name = model_name
        log_name = "/".join(("log_bert", model_name.split("/")[1]))
        print("finish loading and evaluate model %s" % model_name)
        best_eval_reward = evaluate.ext_model_eval(bert_cb, vocab, args,
                                                   args.eval_data)[0]

    logging.basicConfig(filename='%s.log' % log_name,
                        level=logging.DEBUG,
                        format='%(asctime)s %(levelname)-10s %(message)s')

    # Loss and Optimizer (only trainable parameters).
    optimizer_ans = torch.optim.Adam([
        param for param in bert_cb.parameters() if param.requires_grad == True
    ], lr=args.lr, betas=(args.beta, 0.999), weight_decay=1e-6)
    if args.lr_sch == 1:
        # NOTE(review): `epsilon` kwarg suggests a project-local
        # ReduceLROnPlateau rather than torch's — confirm import.
        scheduler = ReduceLROnPlateau(optimizer_ans, 'max', verbose=1,
                                      factor=0.9, patience=3, cooldown=3,
                                      min_lr=9e-5, epsilon=1e-6)
        if best_eval_reward:
            scheduler.step(best_eval_reward, 0)
            print("init_scheduler")
    elif args.lr_sch == 2:
        scheduler = torch.optim.lr_scheduler.CyclicLR(
            optimizer_ans, args.lr, args.lr_2,
            step_size_up=3 * int(num_data / args.train_batch),
            step_size_down=3 * int(num_data / args.train_batch),
            mode='exp_range', gamma=0.98, cycle_momentum=False)

    print("starting training")
    start_time = time.time()
    n_step = 100
    gamma = args.gamma
    # Evaluate more often on small datasets.
    if num_data < 2000:
        n_val = int(num_data / (5 * args.train_batch))
    else:
        n_val = int(num_data / (7 * args.train_batch))

    with torch.autograd.set_detect_anomaly(True):
        for epoch in tqdm(range(args.epochs_ext), desc="epoch:"):
            train_iter = data_loader.chunked_data_reader(
                "train", data_quota=args.train_example_quota)  # -1
            step_in_epoch = 0
            for dataset in train_iter:
                for step, contexts in tqdm(
                        enumerate(
                            BatchDataLoader(dataset,
                                            batch_size=args.train_batch,
                                            shuffle=True))):
                    # NOTE(review): broad try/except keeps training alive on
                    # per-batch failures; errors are printed, not re-raised.
                    try:
                        bert_cb.train()
                        step_in_epoch += 1
                        loss = 0.
                        reward = 0.
                        for context in contexts:
                            pre_processed, a_len, sorted_id = model2.bert_preprocess(
                                context.answers)
                            q_a = torch.autograd.Variable(
                                pre_processed.type(torch.float))
                            a_len = torch.autograd.Variable(a_len)
                            outputs = bert_cb(q_a, a_len)
                            # Reorder labels to match the sorted answers.
                            context.labels = np.array(
                                context.labels)[sorted_id]
                            # Occasionally print diagnostics (~1% of contexts).
                            if args.prt_inf and np.random.randint(0, 100) == 0:
                                prt = True
                            else:
                                prt = False
                            loss_t, reward_t = bandit.train(
                                outputs, context,
                                max_num_of_ans=args.max_num_of_ans,
                                reward_type=args.reward_type,
                                prt=prt)
                            # Binarize gold labels for the BCE term.
                            true_labels = np.zeros(len(context.labels))
                            gold_labels = np.array(context.labels)
                            true_labels[gold_labels > 0] = 1.0
                            ml_loss = F.binary_cross_entropy(
                                outputs.view(-1),
                                torch.tensor(true_labels).type(
                                    torch.float).cuda())
                            # Mix bandit loss and BCE by gamma, accumulate grads.
                            loss_e = ((gamma * loss_t) +
                                      ((1 - gamma) * ml_loss))
                            loss_e.backward()
                            loss += loss_e.item()
                            reward += reward_t
                        loss = loss / args.train_batch
                        reward = reward / args.train_batch
                        if prt:
                            print('Probabilities: ',
                                  outputs.squeeze().data.cpu().numpy())
                            print('-' * 80)
                        reward_list.append(reward)
                        loss_list.append(loss)
                        # Step the optimizer every batch (step % 1 == 0 always).
                        if step % 1 == 0:
                            if args.clip_grad:
                                torch.nn.utils.clip_grad_norm_(
                                    bert_cb.parameters(),
                                    args.clip_grad)  # gradient clipping
                            optimizer_ans.step()
                            optimizer_ans.zero_grad()
                            if args.lr_sch == 2:
                                scheduler.step()
                        logging.info('Epoch %d Step %d Reward %.4f Loss %.4f' %
                                     (epoch, step_in_epoch, reward, loss))
                    except Exception as e:
                        print(e)
                        traceback.print_exc()

                    # Periodic rolling-average log.
                    if (step_in_epoch) % n_step == 0 and step_in_epoch != 0:
                        logging.info('Epoch ' + str(epoch) + ' Step ' +
                                     str(step_in_epoch) + ' reward: ' +
                                     str(np.mean(reward_list)) + ' loss: ' +
                                     str(np.mean(loss_list)))
                        reward_list = []
                        loss_list = []

                    # Periodic evaluation and best-model checkpointing.
                    if (step_in_epoch) % n_val == 0 and step_in_epoch != 0:
                        print("doing evaluation")
                        bert_cb.eval()
                        eval_reward = evaluate.ext_model_eval(
                            bert_cb, vocab, args, args.eval_data)
                        if eval_reward[0] > best_eval_reward:
                            best_eval_reward = eval_reward[0]
                            print(
                                "saving model %s with eval_reward:" %
                                model_save_name, eval_reward)
                            logging.debug("saving model" +
                                          str(model_save_name) +
                                          "with eval_reward:" +
                                          str(eval_reward))
                            torch.save(bert_cb, model_name)
                        print('epoch ' + str(epoch) +
                              ' reward in validation: ' + str(eval_reward))
                        logging.debug('epoch ' + str(epoch) +
                                      ' reward in validation: ' +
                                      str(eval_reward))
            logging.debug('time elapsed:' + str(time.time() - start_time))
            # End-of-epoch plateau scheduling driven by evaluation reward.
            if args.lr_sch == 1:
                bert_cb.eval()
                eval_reward = evaluate.ext_model_eval(bert_cb, vocab, args,
                                                      args.eval_data)
                scheduler.step(eval_reward[0], epoch)
    return bert_cb
def main():
    """Entry point: train a speech-commands classifier on spectrograms.

    Parses CLI args, builds model/loaders/optimizer, optionally resumes
    from a checkpoint and fine-tunes the classifier head, then runs the
    main train/validate loop with checkpointing and a CSV summary.
    """
    args = parser.parse_args()

    if args.output:
        output_base = args.output
    else:
        output_base = './output'
    # Experiment dir name: timestamp-model-pool-fold.
    exp_name = '-'.join([
        datetime.now().strftime("%Y%m%d-%H%M%S"), args.model, args.gp,
        'f' + str(args.fold)])
    output_dir = get_outdir(output_base, 'train', exp_name)

    train_input_root = os.path.join(args.data)
    batch_size = args.batch_size
    num_epochs = args.epochs
    wav_size = (16000,)
    num_classes = len(dataset.get_labels())

    torch.manual_seed(args.seed)

    model = model_factory.create_model(
        args.model,
        in_chs=1,
        pretrained=args.pretrained,
        num_classes=num_classes,
        drop_rate=args.drop,
        global_pool=args.gp,
        checkpoint_path=args.initial_checkpoint)

    dataset_train = dataset.CommandsDataset(
        root=train_input_root,
        mode='train',
        fold=args.fold,
        wav_size=wav_size,
        format='spectrogram',
    )

    loader_train = data.DataLoader(
        dataset_train,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=True,
        num_workers=args.workers
    )

    dataset_eval = dataset.CommandsDataset(
        root=train_input_root,
        mode='validate',
        fold=args.fold,
        wav_size=wav_size,
        format='spectrogram',
    )

    loader_eval = data.DataLoader(
        dataset_eval,
        batch_size=args.batch_size,
        pin_memory=True,
        shuffle=False,
        num_workers=args.workers
    )

    # Same (stateless) loss object is used for train and validation.
    train_loss_fn = validate_loss_fn = torch.nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.cuda()
    validate_loss_fn = validate_loss_fn.cuda()

    opt_params = list(model.parameters())
    # Select the optimizer by name.
    if args.opt.lower() == 'sgd':
        optimizer = optim.SGD(
            opt_params, lr=args.lr,
            momentum=args.momentum, weight_decay=args.weight_decay,
            nesterov=True)
    elif args.opt.lower() == 'adam':
        optimizer = optim.Adam(
            opt_params, lr=args.lr, weight_decay=args.weight_decay,
            eps=args.opt_eps)
    elif args.opt.lower() == 'nadam':
        optimizer = nadam.Nadam(
            opt_params, lr=args.lr, weight_decay=args.weight_decay,
            eps=args.opt_eps)
    elif args.opt.lower() == 'adadelta':
        optimizer = optim.Adadelta(
            opt_params, lr=args.lr, weight_decay=args.weight_decay,
            eps=args.opt_eps)
    elif args.opt.lower() == 'rmsprop':
        optimizer = optim.RMSprop(
            opt_params, lr=args.lr, alpha=0.9, eps=args.opt_eps,
            momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        # NOTE(review): `assert False and "msg"` never shows the message
        # (`False and x` is False) and vanishes under -O; a raised
        # ValueError would be the robust form.
        assert False and "Invalid optimizer"
    del opt_params

    # With no step-decay schedule configured, fall back to plateau scheduling.
    if not args.decay_epochs:
        print('No decay epoch set, using plateau scheduler.')
        lr_scheduler = ReduceLROnPlateau(optimizer, patience=10)
    else:
        lr_scheduler = None

    # Optionally resume from a checkpoint.
    start_epoch = 0 if args.start_epoch is None else args.start_epoch
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                if 'args' in checkpoint:
                    print(checkpoint['args'])
                # Strip any DataParallel 'module.' prefixes from keys.
                new_state_dict = OrderedDict()
                for k, v in checkpoint['state_dict'].items():
                    if k.startswith('module'):
                        name = k[7:]  # remove `module.`
                    else:
                        name = k
                    new_state_dict[name] = v
                model.load_state_dict(new_state_dict)
                if 'optimizer' in checkpoint:
                    optimizer.load_state_dict(checkpoint['optimizer'])
                if 'loss' in checkpoint:
                    train_loss_fn.load_state_dict(checkpoint['loss'])
                print("=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint['epoch']))
                start_epoch = checkpoint['epoch'] if args.start_epoch is None else args.start_epoch
            else:
                # Raw state_dict checkpoint (no metadata wrapper).
                model.load_state_dict(checkpoint)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            exit(1)

    saver = CheckpointSaver(checkpoint_dir=output_dir)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(
            model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    # Optional fine-tune of only the final classifier weights for specified
    # number of epochs (or part of).
    if not args.resume and args.ft_epochs > 0.:
        if isinstance(model, torch.nn.DataParallel):
            classifier_params = model.module.get_classifier().parameters()
        else:
            classifier_params = model.get_classifier().parameters()
        if args.opt.lower() == 'adam':
            finetune_optimizer = optim.Adam(
                classifier_params, lr=args.ft_lr,
                weight_decay=args.weight_decay)
        else:
            finetune_optimizer = optim.SGD(
                classifier_params, lr=args.ft_lr, momentum=args.momentum,
                weight_decay=args.weight_decay, nesterov=True)

        # Fractional ft_epochs: run the last epoch for only part of the loader.
        finetune_epochs_int = int(np.ceil(args.ft_epochs))
        finetune_final_batches = int(np.ceil(
            (1 - (finetune_epochs_int - args.ft_epochs)) * len(loader_train)))
        print(finetune_epochs_int, finetune_final_batches)
        for fepoch in range(0, finetune_epochs_int):
            if fepoch == finetune_epochs_int - 1 and finetune_final_batches:
                batch_limit = finetune_final_batches
            else:
                batch_limit = 0
            train_epoch(
                fepoch, model, loader_train, finetune_optimizer, train_loss_fn,
                args, output_dir=output_dir, batch_limit=batch_limit)

    best_loss = None
    try:
        for epoch in range(start_epoch, num_epochs):
            if args.decay_epochs:
                adjust_learning_rate(
                    optimizer, epoch, initial_lr=args.lr,
                    decay_rate=args.decay_rate, decay_epochs=args.decay_epochs)

            train_metrics = train_epoch(
                epoch, model, loader_train, optimizer, train_loss_fn, args,
                saver=saver, output_dir=output_dir)

            # Save a recovery in case validation blows up.
            saver.save_recovery({
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss': train_loss_fn.state_dict(),
                'args': args,
                'gp': args.gp,
                },
                epoch=epoch + 1,
                batch_idx=0)

            step = epoch * len(loader_train)
            eval_metrics = validate(
                step, model, loader_eval, validate_loss_fn, args,
                output_dir=output_dir)

            if lr_scheduler is not None:
                lr_scheduler.step(eval_metrics['eval_loss'])

            # Append one row per epoch to the experiment summary CSV.
            rowd = OrderedDict(epoch=epoch)
            rowd.update(train_metrics)
            rowd.update(eval_metrics)
            with open(os.path.join(output_dir, 'summary.csv'), mode='a') as cf:
                dw = csv.DictWriter(cf, fieldnames=rowd.keys())
                if best_loss is None:  # first iteration (epoch == 1 can't be used)
                    dw.writeheader()
                dw.writerow(rowd)

            # Save proper checkpoint with eval metric.
            best_loss = saver.save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.model,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'args': args,
                'gp': args.gp,
                },
                epoch=epoch + 1,
                metric=eval_metrics['eval_loss'])

    except KeyboardInterrupt:
        pass
    print('*** Best loss: {0} (epoch {1})'.format(best_loss[1], best_loss[0]))
def main():
    """Entry point: train a CNN-LSTM OCR model with CTC loss.

    Validates every `snapshot_num_iterations` iterations, snapshots the
    current model, keeps the best model by validation WER, and lowers the
    learning rate on WER plateaus (reloading the best weights after each
    LR drop).
    """
    logger.info("Starting training\n\n")
    sys.stdout.flush()

    args = get_args()
    snapshot_path = args.snapshot_prefix + "-cur_snapshot.pth"
    best_model_path = args.snapshot_prefix + "-best_model.pth"

    line_img_transforms = imagetransforms.Compose([
        imagetransforms.Scale(new_h=args.line_height),
        imagetransforms.InvertBlackWhite(),
        imagetransforms.ToTensor(),
    ])

    # Setup cudnn benchmarks for faster code
    torch.backends.cudnn.benchmark = False

    train_dataset = OcrDataset(args.datadir, "train", line_img_transforms)
    validation_dataset = OcrDataset(args.datadir, "validation",
                                    line_img_transforms)

    train_dataloader = DataLoader(train_dataset,
                                  args.batch_size,
                                  num_workers=4,
                                  sampler=GroupedSampler(train_dataset,
                                                         rand=True),
                                  collate_fn=SortByWidthCollater,
                                  pin_memory=True,
                                  drop_last=True)

    validation_dataloader = DataLoader(validation_dataset,
                                       args.batch_size,
                                       num_workers=0,
                                       sampler=GroupedSampler(
                                           validation_dataset, rand=False),
                                       collate_fn=SortByWidthCollater,
                                       pin_memory=False,
                                       drop_last=False)

    n_epochs = args.nepochs
    lr_alpha = args.lr
    snapshot_every_n_iterations = args.snapshot_num_iterations

    if args.load_from_snapshot is not None:
        model = CnnOcrModel.FromSavedWeights(args.load_from_snapshot)
    else:
        model = CnnOcrModel(num_in_channels=1,
                            input_line_height=args.line_height,
                            lstm_input_dim=args.lstm_input_dim,
                            num_lstm_layers=args.num_lstm_layers,
                            num_lstm_hidden_units=args.num_lstm_units,
                            p_lstm_dropout=0.5,
                            alphabet=train_dataset.alphabet,
                            multigpu=True)

    # Set training mode on all sub-modules
    model.train()

    ctc_loss = CTCLoss().cuda()

    iteration = 0
    best_val_wer = float('inf')

    optimizer = torch.optim.Adam(model.parameters(), lr=lr_alpha)
    # NOTE(review): scheduler.step() is used below as a boolean and a
    # `.finished` flag is read — this is a project-local ReduceLROnPlateau,
    # not torch's; confirm the import.
    scheduler = ReduceLROnPlateau(optimizer, mode='min',
                                  patience=args.patience, min_lr=args.min_lr)

    wer_array = []
    cer_array = []
    loss_array = []
    lr_points = []
    iteration_points = []

    epoch_size = len(train_dataloader)

    for epoch in range(1, n_epochs + 1):
        epoch_start = datetime.datetime.now()

        # First modify main OCR model
        for batch in train_dataloader:
            sys.stdout.flush()
            iteration += 1
            iteration_start = datetime.datetime.now()

            loss = train(batch, model, ctc_loss, optimizer)

            elapsed_time = datetime.datetime.now() - iteration_start
            loss = loss / args.batch_size  # per-sample loss for logging
            loss_array.append(loss)

            logger.info(
                "Iteration: %d (%d/%d in epoch %d)\tLoss: %f\tElapsed Time: %s"
                % (iteration, iteration % epoch_size, epoch_size, epoch, loss,
                   pretty_print_timespan(elapsed_time)))

            # Do something with loss, running average, plot to some backend server, etc
            if iteration % snapshot_every_n_iterations == 0:
                logger.info("Testing on validation set")
                val_loss, val_cer, val_wer = test_on_val(
                    validation_dataloader, model, ctc_loss)

                # Reduce learning rate on plateau
                early_exit = False
                lowered_lr = False
                if scheduler.step(val_wer):
                    lowered_lr = True
                    lr_points.append(iteration / snapshot_every_n_iterations)
                    if scheduler.finished:
                        early_exit = True
                    # for bookeeping only
                    lr_alpha = max(lr_alpha * scheduler.factor,
                                   scheduler.min_lr)

                logger.info(
                    "Val Loss: %f\tNo LM Val CER: %f\tNo LM Val WER: %f" %
                    (val_loss, val_cer, val_wer))

                # Snapshot the current state regardless of quality.
                torch.save(
                    {
                        'iteration': iteration,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'model_hyper_params': model.get_hyper_params(),
                        'cur_lr': lr_alpha,
                        'val_loss': val_loss,
                        'val_cer': val_cer,
                        'val_wer': val_wer,
                        'line_height': args.line_height
                    }, snapshot_path)

                # plotting lr_change on wer, cer and loss.
                wer_array.append(val_wer)
                cer_array.append(val_cer)
                iteration_points.append(iteration /
                                        snapshot_every_n_iterations)

                if val_wer < best_val_wer:
                    logger.info(
                        "Best model so far, copying snapshot to best model file"
                    )
                    best_val_wer = val_wer
                    shutil.copyfile(snapshot_path, best_model_path)

                logger.info("Running WER: %s" % str(wer_array))
                logger.info("Done with validation, moving on.")

                if early_exit:
                    logger.info("Early exit")
                    sys.exit(0)

                if lowered_lr:
                    # After an LR drop, restart from the best weights so far.
                    logger.info(
                        "Switching to best model parameters before continuing with lower LR"
                    )
                    weights = torch.load(best_model_path)
                    model.load_state_dict(weights['state_dict'])

        elapsed_time = datetime.datetime.now() - epoch_start
        logger.info("\n------------------")
        logger.info("Done with epoch, elapsed time = %s" %
                    pretty_print_timespan(elapsed_time))
        logger.info("------------------\n")

    logger.info("Done.")
# get some random training images dataiter = iter(trainloader) images, labels = dataiter.next() # show images imshow(torchvision.utils.make_grid(images)) # print labels print(' '.join('%5s' % classes[labels[j]] for j in range(4))) #''' net = Net().cuda() #t1 = net.cuda() criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=1) # set up scheduler n_image_total = 0 running_loss = 0.0 is_lr_just_decayed = False shall_stop = False for epoch in range(n_epoch): # loop over the dataset multiple times for i, data in enumerate(trainloader, 0): # get the inputs inputs, labels = data # wrap them in Variable inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda()) # zero the parameter gradients optimizer.zero_grad()
def train():
    """Train UMX or X-UMX with NNabla (unified variant of the trainer).

    Scales LR/weight-decay by the effective batch size, runs the epoch
    loop with plateau LR scheduling, and saves the best model (by
    validation loss) to `best_umx.h5` or `best_xumx.h5` depending on
    `args.umx_train`.
    """
    # Check NNabla version — memory-efficiency fixes landed in v1.19.0.
    if utils.get_nnabla_version_integer() < 11900:
        raise ValueError(
            'Please update the nnabla version to v1.19.0 or latest version since memory efficiency of core engine is improved in v1.19.0'
        )
    parser, args = get_train_args()

    # Get context.
    ctx = get_extension_context(args.context, device_id=args.device_id)
    comm = CommunicatorWrapper(ctx)
    nn.set_default_context(comm.ctx)
    ext = import_extension_module(args.context)

    # Monitors: setting up monitors for logging.
    monitor_path = args.output
    monitor = Monitor(monitor_path)
    monitor_best_epoch = MonitorSeries('Best epoch', monitor, interval=1)
    monitor_traing_loss = MonitorSeries('Training loss', monitor, interval=1)
    monitor_validation_loss = MonitorSeries('Validation loss', monitor,
                                            interval=1)
    monitor_lr = MonitorSeries('learning rate', monitor, interval=1)
    monitor_time = MonitorTimeElapsed("training time per iteration", monitor,
                                      interval=1)

    if comm.rank == 0:
        if not os.path.isdir(args.output):
            os.makedirs(args.output)

    # Initialize DataIterator for MUSDB18.
    train_source, valid_source, args = load_datasources(parser, args)

    train_iter = data_iterator(
        train_source,
        args.batch_size,
        RandomState(args.seed),
        with_memory_cache=False,
    )

    valid_iter = data_iterator(
        valid_source,
        1,
        RandomState(args.seed),
        with_memory_cache=False,
    )

    if comm.n_procs > 1:
        # Shard both iterators across processes for multi-GPU training.
        train_iter = train_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)
        valid_iter = valid_iter.slice(rng=None,
                                      num_of_slices=comm.n_procs,
                                      slice_pos=comm.rank)

    # Calculate maxiter per GPU device.
    # Change max_iter, learning_rate and weight_decay according no. of gpu
    # devices for multi-gpu training.
    default_batch_size = 16
    train_scale_factor = (comm.n_procs * args.batch_size) / default_batch_size
    max_iter = int((train_source._size // args.batch_size) // comm.n_procs)
    weight_decay = args.weight_decay * train_scale_factor
    args.lr = args.lr * train_scale_factor

    # Calculate the statistics (mean and variance) of the dataset.
    scaler_mean, scaler_std = utils.get_statistics(args, train_source)

    # Clear cache memory.
    ext.clear_memory_cache()

    max_bin = utils.bandwidth_to_max_bin(train_source.sample_rate, args.nfft,
                                         args.bandwidth)

    # Get X-UMX/UMX computation graph and variables as namedtuple.
    model = get_model(args, scaler_mean, scaler_std, max_bin=max_bin)

    # Create Solver and set parameters.
    solver = S.Adam(args.lr)
    solver.set_parameters(nn.get_parameters())

    # Initialize Early Stopping.
    es = utils.EarlyStopping(patience=args.patience)

    # Initialize LR Scheduler (ReduceLROnPlateau — project-local variant
    # driven via `update_lr`).
    lr_scheduler = ReduceLROnPlateau(lr=args.lr,
                                     factor=args.lr_decay_gamma,
                                     patience=args.lr_decay_patience)
    best_epoch = 0

    # AverageMeter for mean loss calculation over the epoch.
    losses = utils.AverageMeter()

    # Training loop.
    for epoch in trange(args.epochs):
        # TRAINING
        losses.reset()
        for batch in range(max_iter):
            model.mixture_audio.d, model.target_audio.d = train_iter.next()
            solver.zero_grad()
            model.loss.forward(clear_no_need_grad=True)
            if comm.n_procs > 1:
                # Gradient all-reduce overlapped with backward.
                all_reduce_callback = comm.get_all_reduce_callback()
                model.loss.backward(
                    clear_buffer=True,
                    communicator_callbacks=all_reduce_callback)
            else:
                model.loss.backward(clear_buffer=True)
            solver.weight_decay(weight_decay)
            solver.update()
            losses.update(model.loss.d.copy(), args.batch_size)
        training_loss = losses.get_avg()

        # Clear cache memory between phases.
        ext.clear_memory_cache()

        # VALIDATION: slide a fixed-duration window over each full clip
        # and average the per-window losses.
        losses.reset()
        for batch in range(int(valid_source._size // comm.n_procs)):
            x, y = valid_iter.next()
            dur = int(valid_source.sample_rate * args.valid_dur)
            sp, cnt = 0, 0
            loss_tmp = nn.NdArray()
            loss_tmp.zero()
            while 1:
                model.vmixture_audio.d = x[Ellipsis, sp:sp + dur]
                model.vtarget_audio.d = y[Ellipsis, sp:sp + dur]
                model.vloss.forward(clear_no_need_grad=True)
                cnt += 1
                sp += dur
                loss_tmp += model.vloss.data
                # Stop when the next window would be short or the clip is done.
                if x[Ellipsis, sp:sp + dur].shape[-1] < dur or x.shape[-1] == cnt * dur:
                    break
            loss_tmp = loss_tmp / cnt
            if comm.n_procs > 1:
                comm.all_reduce(loss_tmp, division=True, inplace=True)
            losses.update(loss_tmp.data.copy(), 1)
        validation_loss = losses.get_avg()

        # Clear cache memory.
        ext.clear_memory_cache()

        lr = lr_scheduler.update_lr(validation_loss, epoch=epoch)
        solver.set_learning_rate(lr)
        stop = es.step(validation_loss)

        if comm.rank == 0:
            monitor_best_epoch.add(epoch, best_epoch)
            monitor_traing_loss.add(epoch, training_loss)
            monitor_validation_loss.add(epoch, validation_loss)
            monitor_lr.add(epoch, lr)
            monitor_time.add(epoch)

            if validation_loss == es.best:
                best_epoch = epoch
                # Save best model.
                if args.umx_train:
                    nn.save_parameters(
                        os.path.join(args.output, 'best_umx.h5'))
                else:
                    nn.save_parameters(
                        os.path.join(args.output, 'best_xumx.h5'))

        if args.umx_train:
            # Early stopping for UMX after `args.patience` (140) number of epochs
            # NOTE(review): the X-UMX (non-umx_train) stopping path, if any,
            # is not visible in this chunk — confirm against the full file.
            if stop:
                print("Apply Early Stopping")
                break
# Build augmented DICOM loaders for both phases, record their sizes, and
# run the training driver with a min-mode plateau scheduler for 5 epochs.
dataset_loaders = {
    phase: DataLoader(Controller(frame), DICOMPreprocessor(augment=True))
    for phase, frame in (('train', train_df), ('val', val_df))
}
dataset_sizes = {
    phase: loader.shape() for phase, loader in dataset_loaders.items()
}

# NOTE(review): constructed with a bare 'min' — presumably a project-local
# ReduceLROnPlateau, not torch's (which requires an optimizer first).
RLRP_agent = ReduceLROnPlateau('min')
num_epochs = 5

best_model = train_model(args, model, criterion, dataset_loaders,
                         dataset_sizes, RLRP_agent, num_epochs)
print(best_model)
def main():
    """Entry point for CNN-LSTM OCR training.

    Builds the line-image augmentation pipeline, datasets/dataloaders,
    constructs (or restores) the CnnOcrModel, then runs the CTC training
    loop with periodic validation, snapshotting, plateau-based LR decay,
    and best-model tracking on validation WER.
    """
    logger.info("Starting training\n\n")
    sys.stdout.flush()
    args = get_args()
    # Rolling snapshot (overwritten every validation) and best-so-far copy.
    snapshot_path = args.snapshot_prefix + "-cur_snapshot.pth"
    best_model_path = args.snapshot_prefix + "-best_model.pth"
    line_img_transforms = []

    #if args.num_in_channels == 3:
    #    line_img_transforms.append(imagetransforms.ConvertColor())

    # Always convert color for the augmentations to work (for now)
    # Then later convert back to grayscale if needed
    line_img_transforms.append(imagetransforms.ConvertColor())

    # Data augmentations (during training only)
    if args.daves_augment:
        line_img_transforms.append(daves_augment.ImageAug())

    if args.synth_input:
        # Randomly rotate image from -2 degrees to +2 degrees
        line_img_transforms.append(
            imagetransforms.Randomize(0.3, imagetransforms.RotateRandom(-2, 2)))

        # Choose one of methods to blur/pixel-ify image (or don't and choose identity)
        line_img_transforms.append(
            imagetransforms.PickOne([
                imagetransforms.TessBlockConv(kernel_val=1, bias_val=1),
                imagetransforms.TessBlockConv(rand=True),
                imagetransforms.Identity(),
            ]))

        aug_cn = iaa.ContrastNormalization((0.5, 2.0), per_channel=0.5)
        line_img_transforms.append(
            imagetransforms.Randomize(0.5, lambda x: aug_cn.augment_image(x)))

        # With some probability, choose one of:
        #   Grayscale: convert to grayscale and add back into color-image with random alpha
        #   Emboss: Emboss image with random strength
        #   Invert: Invert colors of image per-channel
        aug_gray = iaa.Grayscale(alpha=(0.0, 1.0))
        aug_emboss = iaa.Emboss(alpha=(0, 1.0), strength=(0, 2.0))
        aug_invert = iaa.Invert(1, per_channel=True)
        aug_invert2 = iaa.Invert(0.1, per_channel=False)
        line_img_transforms.append(
            imagetransforms.Randomize(
                0.3,
                imagetransforms.PickOne([
                    lambda x: aug_gray.augment_image(x),
                    lambda x: aug_emboss.augment_image(x),
                    lambda x: aug_invert.augment_image(x),
                    lambda x: aug_invert2.augment_image(x)
                ])))

    # Randomly try to crop close to top/bottom and left/right of lines
    # For now we are just guessing (up to 5% of ends and up to 10% of
    # tops/bottoms chopped off)
    if args.tight_crop:
        # To make sure padding is reasonably consistent, we first resize image to target line height
        # Then add padding to this version of image
        # Below it will get resized again to target line height
        line_img_transforms.append(
            imagetransforms.Randomize(
                0.9,
                imagetransforms.Compose([
                    imagetransforms.Scale(new_h=args.line_height),
                    imagetransforms.PadRandom(pxl_max_horizontal=30,
                                              pxl_max_vertical=10)
                ])))
    else:
        line_img_transforms.append(
            imagetransforms.Randomize(0.2, imagetransforms.CropHorizontal(.05)))
        line_img_transforms.append(
            imagetransforms.Randomize(0.2, imagetransforms.CropVertical(.1)))

    #line_img_transforms.append(imagetransforms.Randomize(0.2,
    #    imagetransforms.PickOne([imagetransforms.MorphErode(3), imagetransforms.MorphDilate(3)])
    #    ))

    # Make sure to do resize after degrade step above
    line_img_transforms.append(imagetransforms.Scale(new_h=args.line_height))

    if args.cvtGray:
        line_img_transforms.append(imagetransforms.ConvertGray())

    # Only do for grayscale
    if args.num_in_channels == 1:
        line_img_transforms.append(imagetransforms.InvertBlackWhite())

    if args.stripe:
        # Simulate scanner/printing artifacts with a random dark stripe.
        line_img_transforms.append(
            imagetransforms.Randomize(
                0.3,
                imagetransforms.AddRandomStripe(val=0, strip_width_from=1,
                                                strip_width_to=4)))

    line_img_transforms.append(imagetransforms.ToTensor())
    line_img_transforms = imagetransforms.Compose(line_img_transforms)

    # NOTE(review): comment said "for faster code" but benchmark is set to
    # False here (deterministic algorithm selection); confirm intent.
    torch.backends.cudnn.benchmark = False

    # Single data dir -> plain dataset; multiple dirs -> union dataset.
    if len(args.datadir) == 1:
        train_dataset = OcrDataset(args.datadir[0], "train", line_img_transforms)
        validation_dataset = OcrDataset(args.datadir[0], "validation",
                                        line_img_transforms)
    else:
        train_dataset = OcrDatasetUnion(args.datadir, "train",
                                        line_img_transforms)
        validation_dataset = OcrDatasetUnion(args.datadir, "validation",
                                             line_img_transforms)

    if args.test_datadir is not None:
        if args.test_outdir is None:
            print(
                "Error, must specify both --test-datadir and --test-outdir together"
            )
            sys.exit(1)
        if not os.path.exists(args.test_outdir):
            os.makedirs(args.test_outdir)
        # Test-time transforms: no augmentation, just resize + tensor.
        line_img_transforms_test = imagetransforms.Compose([
            imagetransforms.Scale(new_h=args.line_height),
            imagetransforms.ToTensor()
        ])
        test_dataset = OcrDataset(args.test_datadir, "test",
                                  line_img_transforms_test)

    n_epochs = args.nepochs
    lr_alpha = args.lr
    snapshot_every_n_iterations = args.snapshot_num_iterations

    if args.load_from_snapshot is not None:
        model = CnnOcrModel.FromSavedWeights(args.load_from_snapshot)
        print(
            "Overriding automatically learned alphabet with pre-saved model alphabet"
        )
        # NOTE(review): both branches perform the same two assignments; the
        # union branch additionally pushes the alphabet into each sub-dataset.
        if len(args.datadir) == 1:
            train_dataset.alphabet = model.alphabet
            validation_dataset.alphabet = model.alphabet
        else:
            train_dataset.alphabet = model.alphabet
            validation_dataset.alphabet = model.alphabet
            for ds in train_dataset.datasets:
                ds.alphabet = model.alphabet
            for ds in validation_dataset.datasets:
                ds.alphabet = model.alphabet
    else:
        model = CnnOcrModel(num_in_channels=args.num_in_channels,
                            input_line_height=args.line_height,
                            rds_line_height=args.rds_line_height,
                            lstm_input_dim=args.lstm_input_dim,
                            num_lstm_layers=args.num_lstm_layers,
                            num_lstm_hidden_units=args.num_lstm_units,
                            p_lstm_dropout=0.5,
                            alphabet=train_dataset.alphabet,
                            multigpu=True)

    # Setting dataloader after we have a chance to (maybe!) over-ride the
    # dataset alphabet from a pre-trained model
    train_dataloader = DataLoader(train_dataset,
                                  args.batch_size,
                                  num_workers=4,
                                  sampler=GroupedSampler(train_dataset,
                                                         rand=True),
                                  collate_fn=SortByWidthCollater,
                                  pin_memory=True,
                                  drop_last=True)

    # Optionally cap validation size for faster periodic evaluation.
    if args.max_val_size > 0:
        validation_dataloader = DataLoader(validation_dataset,
                                           args.batch_size,
                                           num_workers=0,
                                           sampler=GroupedSampler(
                                               validation_dataset,
                                               max_items=args.max_val_size,
                                               fixed_rand=True),
                                           collate_fn=SortByWidthCollater,
                                           pin_memory=False,
                                           drop_last=False)
    else:
        validation_dataloader = DataLoader(validation_dataset,
                                           args.batch_size,
                                           num_workers=0,
                                           sampler=GroupedSampler(
                                               validation_dataset,
                                               rand=False),
                                           collate_fn=SortByWidthCollater,
                                           pin_memory=False,
                                           drop_last=False)

    if args.test_datadir is not None:
        test_dataloader = DataLoader(test_dataset,
                                     args.batch_size,
                                     num_workers=0,
                                     sampler=GroupedSampler(test_dataset,
                                                            rand=False),
                                     collate_fn=SortByWidthCollater,
                                     pin_memory=False,
                                     drop_last=False)

    # Set training mode on all sub-modules
    model.train()

    ctc_loss = CTCLoss().cuda()
    iteration = 0
    best_val_wer = float('inf')
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=lr_alpha,
                                 weight_decay=args.weight_decay)
    # NOTE(review): this ReduceLROnPlateau exposes .step() with a truthy
    # return plus .finished/.factor/.min_lr — a project-local scheduler, not
    # torch.optim.lr_scheduler.ReduceLROnPlateau; confirm import.
    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='min',
                                  patience=args.patience,
                                  min_lr=args.min_lr)
    wer_array = []
    cer_array = []
    loss_array = []
    lr_points = []
    iteration_points = []
    epoch_size = len(train_dataloader)
    do_test_write = False

    for epoch in range(1, n_epochs + 1):
        epoch_start = datetime.datetime.now()
        # First modify main OCR model
        for batch in train_dataloader:
            sys.stdout.flush()
            iteration += 1
            iteration_start = datetime.datetime.now()
            loss = train(batch, model, ctc_loss, optimizer)
            elapsed_time = datetime.datetime.now() - iteration_start
            # Per-sample loss for logging.
            loss = loss / args.batch_size
            loss_array.append(loss)
            logger.info(
                "Iteration: %d (%d/%d in epoch %d)\tLoss: %f\tElapsed Time: %s"
                % (iteration, iteration % epoch_size, epoch_size, epoch, loss,
                   pretty_print_timespan(elapsed_time)))

            # Validation / snapshot bookkeeping every N iterations.
            if iteration % snapshot_every_n_iterations == 0:
                logger.info("Testing on validation set")
                val_loss, val_cer, val_wer = test_on_val(
                    validation_dataloader, model, ctc_loss)
                # Only turn on test-on-testset when cer is starting to get
                # non-random
                if val_cer < 0.5:
                    do_test_write = True
                if args.test_datadir is not None and (
                        iteration % snapshot_every_n_iterations
                        == 0) and do_test_write:
                    # Write decoded hypotheses for out-of-domain (test) and
                    # in-domain (validation) data plus a small metrics file.
                    out_hyp_outdomain_file = os.path.join(
                        args.test_outdir, "hyp-%07d.outdomain.utf8" % iteration)
                    out_hyp_indomain_file = os.path.join(
                        args.test_outdir, "hyp-%07d.indomain.utf8" % iteration)
                    out_meta_file = os.path.join(args.test_outdir,
                                                 "hyp-%07d.meta" % iteration)
                    test_on_val_writeout(test_dataloader, model,
                                         out_hyp_outdomain_file)
                    test_on_val_writeout(validation_dataloader, model,
                                         out_hyp_indomain_file)
                    with open(out_meta_file, 'w') as fh_out:
                        fh_out.write("%d,%f,%f,%f\n" %
                                     (iteration, val_cer, val_wer, val_loss))

                # Reduce learning rate on plateau
                early_exit = False
                lowered_lr = False
                if scheduler.step(val_wer):
                    lowered_lr = True
                    lr_points.append(iteration / snapshot_every_n_iterations)
                    if scheduler.finished:
                        early_exit = True
                    # for bookkeeping only
                    lr_alpha = max(lr_alpha * scheduler.factor,
                                   scheduler.min_lr)

                logger.info(
                    "Val Loss: %f\tNo LM Val CER: %f\tNo LM Val WER: %f" %
                    (val_loss, val_cer, val_wer))

                torch.save(
                    {
                        'iteration': iteration,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'model_hyper_params': model.get_hyper_params(),
                        'rtl': args.rtl,
                        'cur_lr': lr_alpha,
                        'val_loss': val_loss,
                        'val_cer': val_cer,
                        'val_wer': val_wer,
                        'line_height': args.line_height
                    }, snapshot_path)

                # plotting lr_change on wer, cer and loss.
                wer_array.append(val_wer)
                cer_array.append(val_cer)
                iteration_points.append(iteration /
                                        snapshot_every_n_iterations)

                if val_wer < best_val_wer:
                    logger.info(
                        "Best model so far, copying snapshot to best model file"
                    )
                    best_val_wer = val_wer
                    shutil.copyfile(snapshot_path, best_model_path)

                logger.info("Running WER: %s" % str(wer_array))
                logger.info("Done with validation, moving on.")

                if early_exit:
                    logger.info("Early exit")
                    sys.exit(0)

                if lowered_lr:
                    # Restart from the best checkpoint when the LR is lowered.
                    logger.info(
                        "Switching to best model parameters before continuing with lower LR"
                    )
                    weights = torch.load(best_model_path)
                    model.load_state_dict(weights['state_dict'])

        elapsed_time = datetime.datetime.now() - epoch_start
        logger.info("\n------------------")
        logger.info("Done with epoch, elapsed time = %s" %
                    pretty_print_timespan(elapsed_time))
        logger.info("------------------\n")

    #writer.close()
    logger.info("Done.")
def train(args, model, processor):
    """Train a CRF-headed BERT NER model and checkpoint the best epoch.

    Fixes vs. original: misspelled timer locals (``strat_*`` -> ``start_*``)
    and removal of dead commented-out code; behavior is otherwise unchanged.

    Args:
        args: namespace providing batch_size, seed, label2id, id2label,
            markup, learning_rate, epochs, device, grad_norm, arch and
            output_dir (a pathlib.Path).
        model: sequence-labelling model exposing ``forward_loss`` and a
            ``crf`` attribute with ``_obtain_labels``.
        processor: data processor supplying the vocabulary.

    Side effects:
        Saves the best checkpoint (by eval F1) to
        ``args.output_dir / 'best-model.bin'``.
    """
    tokenizer = BertTokenizer.from_pretrained(
        './BERT_model/bert_pretrain/vocab.txt')
    train_dataset = load_and_cache_examples(args, processor, data_type='train')
    train_loader = DatasetLoader(data=train_dataset,
                                 batch_size=args.batch_size,
                                 shuffle=False,
                                 seed=args.seed,
                                 sort=True,
                                 vocab=processor.vocab,
                                 label2id=args.label2id,
                                 tokenizer=tokenizer)
    # Only optimize parameters that require gradients.
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(parameters, lr=args.learning_rate)
    # NOTE(review): takes epsilon= and is stepped via epoch_step() below, so
    # this is a project-local ReduceLROnPlateau, not torch's — confirm import.
    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='max',
                                  factor=0.5,
                                  patience=3,
                                  verbose=1,
                                  epsilon=1e-4,
                                  cooldown=0,
                                  min_lr=0,
                                  eps=1e-8)
    train_metric = SeqEntityScore(args.id2label, markup=args.markup)
    best_f1 = 0
    for epoch in range(1, 1 + args.epochs):
        start_epoch_time = time.time()
        logger.info(f"Epoch {epoch}/{args.epochs}")
        train_loss = AverageMeter()
        model.train()
        assert model.training
        for step, batch in enumerate(train_loader):
            start_batch_time = time.time()
            input_ids, input_mask, input_tags, input_lens = batch
            input_ids = input_ids.to(args.device)
            input_mask = input_mask.to(args.device)
            input_tags = input_tags.to(args.device)
            features, loss = model.forward_loss(input_ids, input_mask,
                                                input_lens, input_tags)
            loss.backward()
            # Clip gradients to stabilise CRF/BERT training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            train_loss.update(loss.item(), n=1)
            # Decode predicted tag paths and score them against the gold
            # paths truncated to each sequence's true length.
            tags, _ = model.crf._obtain_labels(features, args.id2label,
                                               input_lens)
            input_tags = input_tags.cpu().numpy()
            target = [
                input_[:len_] for input_, len_ in zip(input_tags, input_lens)
            ]
            pre_train = train_metric.compute_train_pre(label_paths=target,
                                                       pred_paths=tags)
            logger.info(f'time: {time.time() - start_batch_time:.1f} '
                        f'train_loss: {loss.item():.4f} '
                        f'train_pre: {pre_train:.4f}')
        print(" ")
        logger.info(f'train_total_time: {time.time() - start_epoch_time}')
        if 'cuda' in str(args.device):
            torch.cuda.empty_cache()  # release cached GPU memory
        start_eval_time = time.time()
        eval_f1 = evaluate(args, model, processor)
        show_info = (f'eval_time: {time.time() - start_eval_time:.1f} '
                     f'train_avg_loss: {train_loss.avg:.4f} '
                     f'eval_f1: {eval_f1:.4f} ')
        logger.info(show_info)
        scheduler.epoch_step(eval_f1, epoch)
        if eval_f1 > best_f1:
            logger.info(
                f"\nEpoch {epoch}: eval_f1 improved from {best_f1} to {eval_f1}"
            )
            best_f1 = eval_f1
            # NOTE(review): unlike the sibling train() earlier in this file,
            # this does not unwrap nn.DataParallel before saving — if the
            # model can be DataParallel-wrapped, model.module.state_dict()
            # should be saved instead; confirm against callers.
            model_stat_dict = model.state_dict()
            state = {
                'epoch': epoch,
                'arch': args.arch,
                'state_dict': model_stat_dict
            }
            model_path = args.output_dir / 'best-model.bin'
            torch.save(state, str(model_path))