def train(epoch):
    losses = AverageMeter()
    # switch to train mode
    model.train()
    if args.distribute:
        train_sampler.set_epoch(epoch)
    correct = 0
    preds = []
    train_labels = []
    for i, (image, label) in enumerate(train_loader):
        rate = get_learning_rate(optimizer)
        image, label = image.cuda(), label.cuda()

        output = model(image)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure accuracy and record loss
        losses.update(loss.item(), image.size(0))

        if i % args.print_freq == 0 or i == len(train_loader) - 1:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Rate:{rate}\t'
                  'Loss {loss.val:.5f} ({loss.avg:.5f})\t'.format(
                      epoch, i, len(train_loader), rate=rate, loss=losses))
    return
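# The loop above relies on two small helpers that are not defined in this extract.
# A minimal sketch of both follows, assuming AverageMeter keeps a running average of
# the loss and get_learning_rate() simply reads the current lr from the optimizer's
# param_groups; the real project may define them differently.
class AverageMeter:
    """Tracks the most recent value and a running average."""
    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def get_learning_rate(optimizer):
    """Returns the learning rate(s) currently set on the optimizer's param groups."""
    return [group['lr'] for group in optimizer.param_groups]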
def main(config):
    save_path = config['save_path']
    epochs = config['epochs']
    os.environ['TORCH_HOME'] = config['torch_home']
    distributed = config['use_DDP']
    start_ep = 0
    start_cnt = 0

    # initialize model
    print("Initializing model...")
    if distributed:
        initialize_distributed(config)
    rank = config['rank']

    # map string name to class constructor
    model = get_model(config)
    model.apply(init_weights)
    if config['resume_ckpt'] is not None:
        # load weights from checkpoint
        state_dict = load_weights(config['resume_ckpt'])
        model.load_state_dict(state_dict)

    print("Moving model to GPU")
    model.cuda(torch.cuda.current_device())

    print("Setting up losses")
    if config['use_vgg']:
        criterionVGG = Vgg19PerceptualLoss(config['reduced_w'])
        criterionVGG.cuda()
        validationLoss = criterionVGG
    if config['use_gan']:
        use_sigmoid = config['no_lsgan']
        disc_input_channels = 3
        discriminator = MultiscaleDiscriminator(disc_input_channels, config['ndf'],
                                                config['n_layers_D'], 'instance',
                                                use_sigmoid, config['num_D'],
                                                False, False)
        discriminator.apply(init_weights)
        if config['resume_ckpt_D'] is not None:
            # load weights from checkpoint
            print("Resuming discriminator from %s" % (config['resume_ckpt_D']))
            state_dict = load_weights(config['resume_ckpt_D'])
            discriminator.load_state_dict(state_dict)
        discriminator.cuda(torch.cuda.current_device())
        criterionGAN = GANLoss(use_lsgan=not config['no_lsgan'])
        criterionGAN.cuda()
        criterionFeat = nn.L1Loss().cuda()
    if config['use_l2']:
        criterionMSE = nn.MSELoss()
        criterionMSE.cuda()
        validationLoss = criterionMSE

    # initialize dataloader
    print("Setting up dataloaders...")
    train_dataloader, val_dataloader, train_sampler = setup_dataloaders(config)
    print("Done!")

    # run the training loop
    print("Initializing optimizers...")
    optimizer_G = optim.Adam(model.parameters(), lr=config['learning_rate'],
                             weight_decay=config['weight_decay'])
    if config['resume_ckpt_opt_G'] is not None:
        optimizer_G_state_dict = torch.load(
            config['resume_ckpt_opt_G'], map_location=lambda storage, loc: storage)
        optimizer_G.load_state_dict(optimizer_G_state_dict)
    if config['use_gan']:
        optimizer_D = optim.Adam(discriminator.parameters(), lr=config['learning_rate'])
        if config['resume_ckpt_opt_D'] is not None:
            optimizer_D_state_dict = torch.load(
                config['resume_ckpt_opt_D'], map_location=lambda storage, loc: storage)
            optimizer_D.load_state_dict(optimizer_D_state_dict)
    print("Done!")

    if distributed:
        print("Moving model to DDP...")
        model = DDP(model)
        if config['use_gan']:
            discriminator = DDP(discriminator, delay_allreduce=True)
        print("Done!")

    tb_logger = None
    if rank == 0:
        tb_logdir = os.path.join(save_path, 'tbdir')
        if not os.path.exists(tb_logdir):
            os.makedirs(tb_logdir)
        tb_logger = SummaryWriter(tb_logdir)

    # run training
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    log_name = os.path.join(save_path, 'loss_log.txt')
    opt_name = os.path.join(save_path, 'opt.yaml')
    print(config)
    save_options(opt_name, config)
    log_handle = open(log_name, 'a')

    print("Starting training")
    cnt = start_cnt
    assert (config['use_warped'] or config['use_temporal'])

    for ep in range(start_ep, epochs):
        if train_sampler is not None:
            train_sampler.set_epoch(ep)

        for curr_batch in train_dataloader:
            optimizer_G.zero_grad()
            input_a = curr_batch['input_a'].cuda()
            target = curr_batch['target'].cuda()
            if config['use_warped'] and config['use_temporal']:
                input_a = torch.cat((input_a, input_a), 0)
                input_b = torch.cat((curr_batch['input_b'].cuda(),
                                     curr_batch['input_temporal'].cuda()), 0)
                target = torch.cat((target, target), 0)
            elif config['use_temporal']:
                input_b = curr_batch['input_temporal'].cuda()
            elif config['use_warped']:
                input_b = curr_batch['input_b'].cuda()

            output_dict = model(input_a, input_b)
            output_recon = output_dict['reconstruction']

            loss_vgg = loss_G_GAN = loss_G_feat = loss_l2 = 0
            if config['use_vgg']:
                loss_vgg = criterionVGG(output_recon, target) * config['vgg_lambda']
            if config['use_gan']:
                predicted_landmarks = output_dict['input_a_gauss_maps']
                # output_dict['reconstruction'] can be considered normalized
                loss_G_GAN, loss_D_real, loss_D_fake = apply_GAN_criterion(
                    output_recon, target, predicted_landmarks.detach(),
                    discriminator, criterionGAN)
                loss_D = (loss_D_fake + loss_D_real) * 0.5
            if config['use_l2']:
                loss_l2 = criterionMSE(output_recon, target) * config['l2_lambda']

            loss_G = loss_G_GAN + loss_G_feat + loss_vgg + loss_l2
            loss_G.backward()
            # grad_norm clipping
            if not config['no_grad_clip']:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer_G.step()

            if config['use_gan']:
                optimizer_D.zero_grad()
                loss_D.backward()
                # grad_norm clipping
                if not config['no_grad_clip']:
                    torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0)
                optimizer_D.step()

            if distributed:
                if config['use_vgg']:
                    loss_vgg = reduce_tensor(loss_vgg, config['world_size'])

            if rank == 0:
                if cnt % 10 == 0:
                    run_visualization(output_dict, output_recon, target, input_a,
                                      input_b, save_path, tb_logger, cnt)

                print_dict = {"learning_rate": get_learning_rate(optimizer_G)}
                if config['use_vgg']:
                    tb_logger.add_scalar('vgg.loss', loss_vgg, cnt)
                    print_dict['Loss_VGG'] = loss_vgg.data
                if config['use_gan']:
                    tb_logger.add_scalar('gan.loss', loss_G_GAN, cnt)
                    tb_logger.add_scalar('d_real.loss', loss_D_real, cnt)
                    tb_logger.add_scalar('d_fake.loss', loss_D_fake, cnt)
                    print_dict['Loss_G_GAN'] = loss_G_GAN
                    print_dict['Loss_real'] = loss_D_real.data
                    print_dict['Loss_fake'] = loss_D_fake.data
                if config['use_l2']:
                    tb_logger.add_scalar('l2.loss', loss_l2, cnt)
                    print_dict['Loss_L2'] = loss_l2.data

                log_iter(ep, cnt % len(train_dataloader), len(train_dataloader),
                         print_dict, log_handle=log_handle)

            # abort on NaN loss
            if loss_G != loss_G:
                print("NaN!!")
                exit(-2)

            cnt = cnt + 1
        # end of train iter loop

        if cnt % config['val_freq'] == 0 and config['val_freq'] > 0:
            val_loss = run_val(
                model, validationLoss, val_dataloader,
                os.path.join(save_path, 'val_%d_renders' % (ep)))

            if distributed:
                val_loss = reduce_tensor(val_loss, config['world_size'])
            if rank == 0:
                tb_logger.add_scalar('validation.loss', val_loss, cnt)
                log_iter(ep, cnt % len(train_dataloader), len(train_dataloader),
                         {"Loss_VGG": val_loss}, header="Validation loss: ",
                         log_handle=log_handle)

        if rank == 0:
            if (ep % config['save_freq'] == 0):
                fname = 'checkpoint_%d.ckpt' % (ep)
                fname = os.path.join(save_path, fname)
                print("Saving model...")
                save_weights(model, fname, distributed)
                optimizer_g_fname = os.path.join(
                    save_path, 'latest_optimizer_g_state.ckpt')
                torch.save(optimizer_G.state_dict(), optimizer_g_fname)
                if config['use_gan']:
                    fname = 'checkpoint_D_%d.ckpt' % (ep)
                    fname = os.path.join(save_path, fname)
                    save_weights(discriminator, fname, distributed)
                    optimizer_d_fname = os.path.join(
                        save_path, 'latest_optimizer_d_state.ckpt')
                    torch.save(optimizer_D.state_dict(), optimizer_d_fname)
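# reduce_tensor() is used above to average a scalar loss across distributed workers
# before logging. A minimal sketch, assuming torch.distributed is initialized and
# world_size is the number of participating ranks; the project's own helper may
# differ in detail.
import torch.distributed as dist

def reduce_tensor(tensor, world_size):
    """All-reduce a tensor across ranks and divide by world_size (i.e. take the mean)."""
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= world_size
    return rt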
        if args.use_adasum else hvd.Sum)
# hvd.Average is an alternative; Sum is recommended since the learning-rate conversion is simpler
# model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

if args.evaluate:
    exit()

for epoch in range(args.start_epoch, args.epochs):
    epoch_start = time.time()
    # train for one epoch
    train(epoch)
    # evaluate on validation set
    val_score, val_loss = validate()
    scheduler.step()
    lr_rate_1 = get_learning_rate(optimizer)
    epoch_time = time.time() - epoch_start

    if (not args.distribute) or (args.distribute and hvd.rank() == 0):
        print('Epoch[{0}] LR: {lr} Time:{time:.6f} '
              'ValLoss {val_loss:.6f} '
              'Val_Score {val_score:.6f}'.format(epoch, lr=lr_rate_1, time=epoch_time,
                                                 val_loss=val_loss, val_score=val_score))

    is_best = val_score > best_val_score[fold]
    best_val_score[fold] = max(val_score, best_val_score[fold])
    if is_best:
        if (not args.distribute) or (args.distribute and hvd.rank() == 0):
            print("--------current best-------:%f" % best_val_score[fold])
def train(args, train_loader, val_loader, model, val_criterion, optimizer,
          lr_scheduler, epoch, step, tb, max_score, cuda=False):
    """
    Runs the training loop for one epoch.

    args: parsed arguments
    train_loader: data loader for training
    val_loader: data loader for validation
    model: network
    val_criterion: loss used during validation
    optimizer: optimizer
    lr_scheduler: learning-rate scheduler
    epoch: current epoch
    step: global step counter
    tb: tensorboard logger
    max_score: best validation score so far
    cuda: use GPU or not
    """
    model.train()
    train_loss = AverageMeter()
    pbar = tqdm(total=len(train_loader), desc="train_model")
    for idx, data_batch in enumerate(train_loader):
        images, targets, img_names = data_batch
        if cuda:
            images, targets = images.cuda(), targets.cuda()

        inputs = {"images": images, "gts": targets}
        loss = model(inputs)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss.update(loss.item(), n=1)

        if (step + 1) % args.log_step == 0:
            tb.scalar_summary("model/loss", loss.data, step)
            tb.scalar_summary("model/lr", args.lr, step)
            pbar.set_description(desc=f"train_model| loss: {loss.item():5.3f}")

        if (step + 1) % args.val_freq == 0:
            val_scores = validation(args, val_loader, model, val_criterion, step,
                                    cuda=args.cuda)
            logger.info(
                f"| model_name {args.model_name} | step: {step} | PA: {val_scores['PA']} "
                f"| mPA: {val_scores['MPA']} | mIoU: {val_scores['MIOU']} | FWIoU: {val_scores['FWIOU']}"
            )
            logger.info(
                f"| model_name {args.model_name} | step: {step} | IOU: {val_scores['IOU']}"
            )
            tb.scalar_summary("val/PA", val_scores["PA"], step)
            tb.scalar_summary("val/mPA", val_scores["MPA"], step)
            tb.scalar_summary("val/mIoU", val_scores["MIOU"], step)
            tb.scalar_summary("val/FWIoU", val_scores["FWIOU"], step)
            for c, iou_c in enumerate(val_scores["IOU"]):
                tb.scalar_summary(f"val/IOU_{cfg.DATASET.TRAINID_TO_ID[c]}", iou_c, step)

            max_score = max(max_score, val_scores["FWIOU"])
            logger.info(f"[*] Step: {step}, max_score: {max_score}.")

            if args.lr_schedule == "reduce_lr_on_plateau":
                lr_scheduler.step(val_scores["FWIOU"])
            else:
                lr_scheduler.step()
            args.lr = get_learning_rate(optimizer)[0]

            state_dict = {
                "epoch": epoch,
                "step": step,
                "state_dict": model.state_dict(),
                "max_score": max_score,
                "optimizer": optimizer.state_dict()
            }
            # save_checkpoint(state_dict, step, is_best, args)
            save_model(state_dict, step, args, val_scores, max_save_num=5,
                       save_criterion="FWIOU")

        step += 1
        pbar.update(1)

    pbar.close()
    return train_loss.avg, step, max_score
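# tb.scalar_summary() above is a thin logging wrapper rather than the raw
# SummaryWriter API. A minimal sketch, assuming it simply forwards to
# torch.utils.tensorboard's add_scalar; the project's own logger may buffer values
# or prefix tags differently.
from torch.utils.tensorboard import SummaryWriter

class TBLogger:
    def __init__(self, log_dir):
        self.writer = SummaryWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        # value may be a float or a 0-dim tensor; add_scalar accepts both
        self.writer.add_scalar(tag, value, step)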
def main(args):
    source_train_set = custom_dataset(args.train_data_path, args.train_gt_path)
    valid_train_set = valid_dataset(args.val_data_path, args.val_gt_path, data_flag='ic13')
    source_train_loader = data.DataLoader(source_train_set, batch_size=args.batch_size,
                                          shuffle=True, num_workers=args.num_workers,
                                          drop_last=True)
    valid_loader = data.DataLoader(valid_train_set, batch_size=args.batch_size,
                                   shuffle=False, num_workers=args.num_workers,
                                   drop_last=False)

    criterion = Loss().to(device)
    best_loss = 1000
    best_num = 0
    current_epoch_num = 0

    model = EAST()
    if args.pretrained_model_path:
        model.load_state_dict(torch.load(args.pretrained_model_path))
    # resume model weights, best loss and epoch counter
    if args.resume:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        best_loss = checkpoint['best_loss']
        current_epoch_num = checkpoint['epoch']

    data_parallel = False
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        data_parallel = True
    model.to(device)

    total_epoch = args.epochs
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[total_epoch // 3, total_epoch * 2 // 3], gamma=0.1)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1,
                                               patience=10, threshold=args.lr / 100)
    # resume scheduler state (the scheduler must exist before its state can be loaded)
    if args.resume:
        checkpoint = torch.load(args.resume)
        scheduler.load_state_dict(checkpoint['scheduler'])

    for epoch in range(current_epoch_num, total_epoch):
        each_epoch_start = time.time()
        # scheduler.step(epoch)
        # add lr in tensorboardX
        writer.add_scalar('epoch/lr', get_learning_rate(optimizer), epoch)

        train(source_train_loader, model, criterion, optimizer, epoch)
        val_loss = eval(model, valid_loader, criterion, epoch)
        scheduler.step(val_loss)

        if val_loss < best_loss:
            best_num = epoch + 1
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.module.state_dict()
                                           if data_parallel else model.state_dict())
            # save best model
            torch.save(
                {
                    'epoch': epoch + 1,
                    'state_dict': best_model_wts,
                    'best_loss': best_loss,
                    'scheduler': scheduler.state_dict(),
                }, os.path.join(save_folder, "model_epoch_best.pth"))
            log.write('best model num:{}, best loss is {:.8f}'.format(best_num, best_loss))
            log.write('\n')

        if (epoch + 1) % int(args.save_interval) == 0:
            state_dict = model.module.state_dict() if data_parallel else model.state_dict()
            torch.save(
                {
                    'epoch': epoch + 1,
                    'state_dict': state_dict,
                    'best_loss': best_loss,
                    'scheduler': scheduler.state_dict(),
                }, os.path.join(save_folder, 'model_epoch_{}.pth'.format(epoch + 1)))
            log.write('save model')
            log.write('\n')

        log.write('=' * 50)
        log.write('\n')
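# The checkpoints saved above store a dict with 'epoch', 'state_dict', 'best_loss'
# and 'scheduler'. A minimal sketch of how such a file could be restored for
# inference, assuming the same EAST model class; load_best_model and ckpt_path are
# hypothetical names, not part of the original code.
def load_best_model(ckpt_path, device):
    model = EAST()
    checkpoint = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    model.to(device)
    model.eval()
    return model, checkpoint['best_loss']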
def train_net(args):
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_loss = float('inf')
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
            pretrained=False, progress=True, num_classes=2, num_keypoints=14,
            pretrained_backbone=True)
        model = nn.DataParallel(model)

        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                        momentum=args.mom,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                         weight_decay=args.weight_decay)
    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)

    # Custom dataloaders
    train_dataset = KpDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1,
                                               shuffle=True, num_workers=1)
    valid_dataset = KpDataset('valid')
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=1,
                                               shuffle=False, num_workers=1)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        if epochs_since_improvement == 10:
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % 2 == 0:
            adjust_learning_rate(optimizer, 0.6)

        # One epoch's training
        train_loss = train(train_loader=train_loader, model=model,
                           optimizer=optimizer, epoch=epoch, logger=logger)
        effective_lr = get_learning_rate(optimizer)
        print('Current effective learning rate: {}\n'.format(effective_lr))
        writer.add_scalar('Train_Loss', train_loss, epoch)

        # One epoch's validation
        valid_loss = valid(valid_loader=valid_loader, model=model, logger=logger)
        writer.add_scalar('Valid_Loss', valid_loss, epoch)

        # Check if there was an improvement
        is_best = valid_loss < best_loss
        best_loss = min(valid_loss, best_loss)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" % (epochs_since_improvement,))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, optimizer,
                        best_loss, is_best)
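# adjust_learning_rate() above shrinks the learning rate when validation stops
# improving. A minimal sketch, assuming it multiplies every param group's lr by the
# given factor; the actual helper in the source project may also print or clamp the
# new value.
def adjust_learning_rate(optimizer, shrink_factor):
    """Multiplies each param group's learning rate by shrink_factor (e.g. 0.6)."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = param_group['lr'] * shrink_factor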
    optimizer.load_state_dict(torch.load("{}".format(os.path.join(
        load_dir, 'network.optimizer.epoch{}'.format(hp.loaded_epoch)))))
    dataloader = DataLoader(dataset_train, batch_sampler=sampler, num_workers=1,
                            collate_fn=collate_fn_transformer)
    step = hp.loaded_epoch * len(dataloader)
else:
    start_epoch = 0
    step = 1

for epoch in range(start_epoch, hp.max_epoch):
    dataloader = DataLoader(dataset_train, batch_sampler=sampler, num_workers=4,
                            collate_fn=collate_fn_transformer)
    #pbar = tqdm(dataloader)
    #for d in pbar:
    for d in dataloader:
        if hp.optimizer.lower() != 'radam':
            lr = get_learning_rate(step, hp.d_model_decoder, hp.warmup_factor,
                                   hp.warmup_step)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        text, mel, pos_text, pos_mel, text_lengths, mel_lengths, stop_token, \
            spk_emb, f0, energy, alignment = d
        text = text.to(DEVICE, non_blocking=True)
        mel = mel.to(DEVICE, non_blocking=True)
        pos_text = pos_text.to(DEVICE, non_blocking=True)
        pos_mel = pos_mel.to(DEVICE, non_blocking=True)
        mel_lengths = mel_lengths.to(DEVICE, non_blocking=True)
        text_lengths = text_lengths.to(DEVICE, non_blocking=True)
        stop_token = stop_token.to(DEVICE, non_blocking=True)
        if hp.is_multi_speaker:
            spk_emb = spk_emb.to(DEVICE, non_blocking=True)
        if hp.pitch_pred: