Example #1
def train(epoch):
    losses = AverageMeter()
    # switch to train mode
    model.train()
    if args.distribute:
        train_sampler.set_epoch(epoch)
    correct = 0
    preds = []
    train_labels = []
    for i, (image, label) in enumerate(train_loader):
        rate = get_learning_rate(optimizer)
        image, label = image.cuda(), label.cuda()

        output = model(image)
        loss = criterion(output, label)

        optimizer.zero_grad()

        loss.backward()
        optimizer.step()

        # measure accuracy and record loss
        losses.update(loss.item(), image.size(0))
        if i % args.print_freq == 0 or i == len(train_loader) - 1:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Rate:{rate}\t'
                  'Loss {loss.val:.5f} ({loss.avg:.5f})\t'.format(
                      epoch, i, len(train_loader), rate=rate, loss=losses))

    return
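Example #1 (and Example #4 below) rely on an AverageMeter helper that is not shown in the snippet. A minimal sketch of the commonly used version is given here as an assumption; it tracks the latest value and a running average, matching the losses.val / losses.avg fields used in the print statement above.

class AverageMeter:
    """Tracks the most recent value and a running average (assumed helper)."""

    def __init__(self):
        self.val = 0.0    # last value passed to update()
        self.sum = 0.0    # weighted sum of all values seen so far
        self.count = 0    # total weight (e.g. number of samples)
        self.avg = 0.0    # running average

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count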
def main(config):
    save_path = config['save_path']
    epochs = config['epochs']
    os.environ['TORCH_HOME'] = config['torch_home']
    distributed = config['use_DDP']
    start_ep = 0
    start_cnt = 0

    # initialize model
    print("Initializing model...")
    if distributed:
        initialize_distributed(config)
    rank = config['rank']

    # map string name to class constructor
    model = get_model(config)
    model.apply(init_weights)
    if config['resume_ckpt'] is not None:
        # load weights from checkpoint
        state_dict = load_weights(config['resume_ckpt'])
        model.load_state_dict(state_dict)

    print("Moving model to GPU")
    model.cuda(torch.cuda.current_device())
    print("Setting up losses")

    if config['use_vgg']:
        criterionVGG = Vgg19PerceptualLoss(config['reduced_w'])
        criterionVGG.cuda()
        validationLoss = criterionVGG
    if config['use_gan']:
        use_sigmoid = config['no_lsgan']
        disc_input_channels = 3
        discriminator = MultiscaleDiscriminator(disc_input_channels,
                                                config['ndf'],
                                                config['n_layers_D'],
                                                'instance', use_sigmoid,
                                                config['num_D'], False, False)
        discriminator.apply(init_weights)
        if config['resume_ckpt_D'] is not None:
            # load weights from checkpoint
            print("Resuming discriminator from %s" % (config['resume_ckpt_D']))
            state_dict = load_weights(config['resume_ckpt_D'])
            discriminator.load_state_dict(state_dict)

        discriminator.cuda(torch.cuda.current_device())
        criterionGAN = GANLoss(use_lsgan=not config['no_lsgan'])
        criterionGAN.cuda()
        criterionFeat = nn.L1Loss().cuda()
    if config['use_l2']:
        criterionMSE = nn.MSELoss()
        criterionMSE.cuda()
        validationLoss = criterionMSE

    # initialize dataloader
    print("Setting up dataloaders...")
    train_dataloader, val_dataloader, train_sampler = setup_dataloaders(config)
    print("Done!")
    # run the training loop
    print("Initializing optimizers...")
    optimizer_G = optim.Adam(model.parameters(),
                             lr=config['learning_rate'],
                             weight_decay=config['weight_decay'])
    if config['resume_ckpt_opt_G'] is not None:
        optimizer_G_state_dict = torch.load(
            config['resume_ckpt_opt_G'],
            map_location=lambda storage, loc: storage)
        optimizer_G.load_state_dict(optimizer_G_state_dict)
    if config['use_gan']:
        optimizer_D = optim.Adam(discriminator.parameters(),
                                 lr=config['learning_rate'])
        if config['resume_ckpt_opt_D'] is not None:
            optimizer_D_state_dict = torch.load(
                config['resume_ckpt_opt_D'],
                map_location=lambda storage, loc: storage)
            optimizer_D.load_state_dict(optimizer_D_state_dict)

    print("Done!")

    if distributed:
        print("Moving model to DDP...")
        model = DDP(model)
        if config['use_gan']:
            discriminator = DDP(discriminator, delay_allreduce=True)
        print("Done!")

    tb_logger = None
    if rank == 0:
        tb_logdir = os.path.join(save_path, 'tbdir')
        if not os.path.exists(tb_logdir):
            os.makedirs(tb_logdir)
        tb_logger = SummaryWriter(tb_logdir)
        # run training
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        log_name = os.path.join(save_path, 'loss_log.txt')
        opt_name = os.path.join(save_path, 'opt.yaml')
        print(config)
        save_options(opt_name, config)
        log_handle = open(log_name, 'a')

    print("Starting training")
    cnt = start_cnt
    assert (config['use_warped'] or config['use_temporal'])

    for ep in range(start_ep, epochs):
        if train_sampler is not None:
            train_sampler.set_epoch(ep)

        for curr_batch in train_dataloader:
            optimizer_G.zero_grad()
            input_a = curr_batch['input_a'].cuda()
            target = curr_batch['target'].cuda()
            if config['use_warped'] and config['use_temporal']:
                input_a = torch.cat((input_a, input_a), 0)
                input_b = torch.cat((curr_batch['input_b'].cuda(),
                                     curr_batch['input_temporal'].cuda()), 0)
                target = torch.cat((target, target), 0)
            elif config['use_temporal']:
                input_b = curr_batch['input_temporal'].cuda()
            elif config['use_warped']:
                input_b = curr_batch['input_b'].cuda()

            output_dict = model(input_a, input_b)
            output_recon = output_dict['reconstruction']

            loss_vgg = loss_G_GAN = loss_G_feat = loss_l2 = 0
            if config['use_vgg']:
                loss_vgg = criterionVGG(output_recon,
                                        target) * config['vgg_lambda']
            if config['use_gan']:
                predicted_landmarks = output_dict['input_a_gauss_maps']
                # output_dict['reconstruction'] can be considered normalized
                loss_G_GAN, loss_D_real, loss_D_fake = apply_GAN_criterion(
                    output_recon, target, predicted_landmarks.detach(),
                    discriminator, criterionGAN)
                loss_D = (loss_D_fake + loss_D_real) * 0.5
            if config['use_l2']:
                loss_l2 = criterionMSE(output_recon,
                                       target) * config['l2_lambda']

            loss_G = loss_G_GAN + loss_G_feat + loss_vgg + loss_l2
            loss_G.backward()
            # grad_norm clipping
            if not config['no_grad_clip']:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer_G.step()
            if config['use_gan']:
                optimizer_D.zero_grad()
                loss_D.backward()
                # grad_norm clipping
                if not config['no_grad_clip']:
                    torch.nn.utils.clip_grad_norm_(discriminator.parameters(),
                                                   1.0)
                optimizer_D.step()

            if distributed:
                if config['use_vgg']:
                    loss_vgg = reduce_tensor(loss_vgg, config['world_size'])

            if rank == 0:
                if cnt % 10 == 0:
                    run_visualization(output_dict, output_recon, target,
                                      input_a, input_b, save_path, tb_logger,
                                      cnt)

                print_dict = {"learning_rate": get_learning_rate(optimizer_G)}
                if config['use_vgg']:
                    tb_logger.add_scalar('vgg.loss', loss_vgg, cnt)
                    print_dict['Loss_VGG'] = loss_vgg.data
                if config['use_gan']:
                    tb_logger.add_scalar('gan.loss', loss_G_GAN, cnt)
                    tb_logger.add_scalar('d_real.loss', loss_D_real, cnt)
                    tb_logger.add_scalar('d_fake.loss', loss_D_fake, cnt)
                    print_dict['Loss_G_GAN'] = loss_G_GAN
                    print_dict['Loss_real'] = loss_D_real.data
                    print_dict['Loss_fake'] = loss_D_fake.data
                if config['use_l2']:
                    tb_logger.add_scalar('l2.loss', loss_l2, cnt)
                    print_dict['Loss_L2'] = loss_l2.data

                log_iter(ep,
                         cnt % len(train_dataloader),
                         len(train_dataloader),
                         print_dict,
                         log_handle=log_handle)

            # NaN check: NaN is the only value that compares unequal to itself
            if loss_G != loss_G:
                print("NaN!!")
                exit(-2)

            cnt = cnt + 1
            # end of train iter loop

            if cnt % config['val_freq'] == 0 and config['val_freq'] > 0:
                val_loss = run_val(
                    model, validationLoss, val_dataloader,
                    os.path.join(save_path, 'val_%d_renders' % (ep)))

                if distributed:
                    val_loss = reduce_tensor(val_loss, config['world_size'])
                if rank == 0:
                    tb_logger.add_scalar('validation.loss', val_loss, cnt)
                    log_iter(ep,
                             cnt % len(train_dataloader),
                             len(train_dataloader), {"Loss_VGG": val_loss},
                             header="Validation loss: ",
                             log_handle=log_handle)

        if rank == 0:
            if (ep % config['save_freq'] == 0):
                fname = 'checkpoint_%d.ckpt' % (ep)
                fname = os.path.join(save_path, fname)
                print("Saving model...")
                save_weights(model, fname, distributed)
                optimizer_g_fname = os.path.join(
                    save_path, 'latest_optimizer_g_state.ckpt')
                torch.save(optimizer_G.state_dict(), optimizer_g_fname)
                if config['use_gan']:
                    fname = 'checkpoint_D_%d.ckpt' % (ep)
                    fname = os.path.join(save_path, fname)
                    save_weights(discriminator, fname, distributed)
                    optimizer_d_fname = os.path.join(
                        save_path, 'latest_optimizer_d_state.ckpt')
                    torch.save(optimizer_D.state_dict(), optimizer_d_fname)
Example #3
            if args.use_adasum else hvd.Sum)  # hvd.Average is an alternative; Sum is recommended since the lr conversion is simpler
    # model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    if args.evaluate:
        exit()

    for epoch in range(args.start_epoch, args.epochs):
        epoch_start = time.time()
        # train for one epoch
        train(epoch)

        # evaluate on validation set
        val_score, val_loss = validate()
        scheduler.step()

        lr_rate_1 = get_learning_rate(optimizer)
        epoch_time = time.time() - epoch_start
        if (not args.distribute) or (args.distribute and hvd.rank() == 0):
            print('Epoch[{0}] LR: {lr} Time:{time:.6f} '
                  'ValLoss {val_loss:.6f}  '
                  'Val_Score {val_score:.6f}'.format(epoch,
                                                     lr=lr_rate_1,
                                                     time=epoch_time,
                                                     val_loss=val_loss,
                                                     val_score=val_score))

        is_best = val_score > best_val_score[fold]
        best_val_score[fold] = max(val_score, best_val_score[fold])
        if is_best:
            if (not args.distribute) or (args.distribute and hvd.rank() == 0):
                print("--------current best-------:%f" % best_val_score[fold])
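The truncated snippet above passes op=hvd.Sum (or hvd.Adasum when args.use_adasum is set) to Horovod's DistributedOptimizer, and its comment notes that summing keeps the learning-rate conversion simple. Below is a minimal, hedged sketch of that pattern; base_lr and use_adasum are illustrative names, not taken from the snippet.

import torch
import horovod.torch as hvd

hvd.init()
model = torch.nn.Linear(10, 2)

base_lr = 0.01            # illustrative value, not from the snippet
use_adasum = False        # stand-in for args.use_adasum

# With op=hvd.Sum the per-step gradient is the sum over hvd.size() workers, so
# the single-GPU base lr already behaves like a linearly scaled lr for the
# larger effective batch; with op=hvd.Average you would typically multiply the
# lr by hvd.size() to get the same effect.
optimizer = torch.optim.SGD(model.parameters(), lr=base_lr)
optimizer = hvd.DistributedOptimizer(
    optimizer,
    named_parameters=model.named_parameters(),
    op=hvd.Adasum if use_adasum else hvd.Sum)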
Example #4
def train(args,
          train_loader,
          val_loader,
          model,
          val_criterion,
          optimizer,
          lr_scheduler,
          epoch,
          step,
          tb,
          max_score,
          cuda=False):
    """
    Runs the training loop per epoch.
    dataloader: Data loader for train
    args: args
    net: network
    optimizer: optimizer
    cur_epoch: current epoch
    cuda: use gpu or not.
    """
    model.train()

    train_loss = AverageMeter()
    pbar = tqdm(total=len(train_loader), desc="train_model")

    for idx, data_batch in enumerate(train_loader):
        images, targets, img_names = data_batch
        if cuda:
            images, targets = images.cuda(), targets.cuda()

        inputs = {"images": images, "gts": targets}
        loss = model(inputs)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss.update(loss.item(), n=1)

        if (step + 1) % args.log_step == 0:
            tb.scalar_summary("model/loss", loss.data, step)
            tb.scalar_summary("model/lr", args.lr, step)

        pbar.set_description(desc=f"train_model| loss: {loss.item():5.3f}")

        if (step + 1) % args.val_freq == 0:
            val_scores = validation(args,
                                    val_loader,
                                    model,
                                    val_criterion,
                                    step,
                                    cuda=args.cuda)

            logger.info(
                f"| model_name {args.model_name} | step: {step} | PA: {val_scores['PA']} "
                f"| mPA: {val_scores['MPA']} | mIoU: {val_scores['MIOU']} | FWIoU: {val_scores['FWIOU']}"
            )
            logger.info(
                f"| model_name {args.model_name} | step: {step} | IOU: {val_scores['IOU']}"
            )

            tb.scalar_summary("val/PA", val_scores["PA"], step)
            tb.scalar_summary("val/mPA", val_scores["MPA"], step)
            tb.scalar_summary("val/mIoU", val_scores["MIOU"], step)
            tb.scalar_summary("val/FWIoU", val_scores["FWIOU"], step)
            for c, iou_c in enumerate(val_scores["IOU"]):
                tb.scalar_summary(f"val/IOU_{cfg.DATASET.TRAINID_TO_ID[c]}",
                                  iou_c, step)

            max_score = max(max_score, val_scores["FWIOU"])

            logger.info(f"[*] Step: {step}, max_score: {max_score}.")

            if args.lr_schedule == "reduce_lr_on_plateau":
                lr_scheduler.step(val_scores["FWIOU"])
            else:
                lr_scheduler.step()

            args.lr = get_learning_rate(optimizer)[0]

            state_dict = {
                "epoch": epoch,
                "step": step,
                "state_dict": model.state_dict(),
                "max_score": max_score,
                "optimizer": optimizer.state_dict()
            }

            # save_checkpoint(state_dict, step, is_best, args)
            save_model(state_dict,
                       step,
                       args,
                       val_scores,
                       max_save_num=5,
                       save_criterion="FWIOU")

        step += 1
        pbar.update(1)

    return train_loss.avg, step, max_score
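For context, the returned (train_loss.avg, step, max_score) are typically threaded back into the next epoch's call. A hedged usage sketch, with the loaders, model, logger, and args standing in for objects built elsewhere in this example:

# Hypothetical outer loop illustrating how step and max_score carry across epochs.
step, max_score = 0, 0.0
for epoch in range(args.epochs):
    avg_loss, step, max_score = train(args, train_loader, val_loader, model,
                                      val_criterion, optimizer, lr_scheduler,
                                      epoch, step, tb, max_score,
                                      cuda=args.cuda)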
def main(args):
    source_train_set = custom_dataset(args.train_data_path, args.train_gt_path)
    valid_train_set = valid_dataset(args.val_data_path,
                                    args.val_gt_path,
                                    data_flag='ic13')

    source_train_loader = data.DataLoader(source_train_set,
                                          batch_size=args.batch_size,
                                          shuffle=True,
                                          num_workers=args.num_workers,
                                          drop_last=True)

    valid_loader = data.DataLoader(valid_train_set,
                                   batch_size=args.batch_size,
                                   shuffle=False,
                                   num_workers=args.num_workers,
                                   drop_last=False)

    criterion = Loss().to(device)

    best_loss = 1000
    best_num = 0
    current_epoch_num = 0

    model = EAST()
    if args.pretrained_model_path:
        model.load_state_dict(torch.load(args.pretrained_model_path))

    # resume
    if args.resume:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        best_loss = checkpoint['best_loss']
        current_epoch_num = checkpoint['epoch']

    data_parallel = False
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        data_parallel = True

    model.to(device)

    total_epoch = args.epochs
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    # scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[total_epoch // 3, total_epoch * 2 // 3], gamma=0.1)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                               mode='min',
                                               factor=0.1,
                                               patience=10,
                                               threshold=args.lr / 100)

    # resume scheduler state (it can only be restored after the scheduler is constructed)
    if args.resume:
        checkpoint = torch.load(args.resume)
        scheduler.load_state_dict(checkpoint['scheduler'])

    for epoch in range(current_epoch_num, total_epoch):
        each_epoch_start = time.time()
        # scheduler.step(epoch)
        # add lr in tensorboardX
        writer.add_scalar('epoch/lr', get_learning_rate(optimizer), epoch)

        train(source_train_loader, model, criterion, optimizer, epoch)

        val_loss = eval(model, valid_loader, criterion, epoch)
        scheduler.step(val_loss)

        if val_loss < best_loss:
            best_num = epoch + 1
            best_loss = val_loss
            best_model_wts = copy.deepcopy(model.module.state_dict(
            ) if data_parallel else model.state_dict())
            # save best model

            torch.save(
                {
                    'epoch': epoch + 1,
                    'state_dict': best_model_wts,
                    'best_loss': best_loss,
                    'scheduler': scheduler.state_dict(),
                }, os.path.join(save_folder, "model_epoch_best.pth"))

            log.write('best model num:{}, best loss is {:.8f}'.format(
                best_num, best_loss))
            log.write('\n')

        if (epoch + 1) % int(args.save_interval) == 0:
            state_dict = model.module.state_dict(
            ) if data_parallel else model.state_dict()
            torch.save(
                {
                    'epoch': epoch + 1,
                    'state_dict': state_dict,
                    'best_loss': best_loss,
                    'scheduler': scheduler.state_dict(),
                },
                os.path.join(save_folder,
                             'model_epoch_{}.pth'.format(epoch + 1)))
            log.write('save model')
            log.write('\n')

        log.write('=' * 50)
        log.write('\n')
Example #6
def train_net(args):
    torch.manual_seed(7)
    np.random.seed(7)
    checkpoint = args.checkpoint
    start_epoch = 0
    best_loss = float('inf')
    writer = SummaryWriter()
    epochs_since_improvement = 0

    # Initialize / load checkpoint
    if checkpoint is None:
        model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
            pretrained=False,
            progress=True,
            num_classes=2,
            num_keypoints=14,
            pretrained_backbone=True)
        model = nn.DataParallel(model)

        if args.optimizer == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(),
                                        lr=args.lr,
                                        momentum=args.mom,
                                        weight_decay=args.weight_decay)
        else:
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)

    else:
        checkpoint = torch.load(checkpoint)
        start_epoch = checkpoint['epoch'] + 1
        epochs_since_improvement = checkpoint['epochs_since_improvement']
        model = checkpoint['model']
        optimizer = checkpoint['optimizer']

    logger = get_logger()

    # Move to GPU, if available
    model = model.to(device)

    # Custom dataloaders
    train_dataset = KpDataset('train')
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=1,
                                               shuffle=True,
                                               num_workers=1)
    valid_dataset = KpDataset('valid')
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=1,
                                               shuffle=False,
                                               num_workers=1)

    # Epochs
    for epoch in range(start_epoch, args.end_epoch):
        if epochs_since_improvement == 10:
            break

        if epochs_since_improvement > 0 and epochs_since_improvement % 2 == 0:
            adjust_learning_rate(optimizer, 0.6)

        # One epoch's training
        train_loss = train(train_loader=train_loader,
                           model=model,
                           optimizer=optimizer,
                           epoch=epoch,
                           logger=logger)
        effective_lr = get_learning_rate(optimizer)
        print('Current effective learning rate: {}\n'.format(effective_lr))

        writer.add_scalar('Train_Loss', train_loss, epoch)

        # One epoch's validation
        valid_loss = valid(valid_loader=valid_loader,
                           model=model,
                           logger=logger)

        writer.add_scalar('Valid_Loss', valid_loss, epoch)

        # Check if there was an improvement
        is_best = valid_loss < best_loss
        best_loss = min(valid_loss, best_loss)
        if not is_best:
            epochs_since_improvement += 1
            print("\nEpochs since last improvement: %d\n" %
                  (epochs_since_improvement, ))
        else:
            epochs_since_improvement = 0

        # Save checkpoint
        save_checkpoint(epoch, epochs_since_improvement, model, optimizer,
                        best_loss, is_best)
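Examples #1, #3, #4, and #6 all call get_learning_rate(optimizer) (and Example #6 also calls adjust_learning_rate(optimizer, 0.6)) without showing these helpers. A common implementation, sketched here as an assumption, simply reads or rescales the lr entry of each optimizer.param_groups group:

import torch


def get_learning_rate(optimizer):
    # Return the current learning rate of every parameter group.
    return [group['lr'] for group in optimizer.param_groups]


def adjust_learning_rate(optimizer, factor):
    # Multiply every group's learning rate by a shrink factor (0.6 above).
    for group in optimizer.param_groups:
        group['lr'] *= factor


# Usage sketch
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
print(get_learning_rate(optimizer))   # [0.1]
adjust_learning_rate(optimizer, 0.6)
print(get_learning_rate(optimizer))   # ~[0.06]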
Example #7
        optimizer.load_state_dict(torch.load(os.path.join(load_dir, 'network.optimizer.epoch{}'.format(hp.loaded_epoch))))
    
        dataloader = DataLoader(dataset_train, batch_sampler=sampler, num_workers=1, collate_fn=collate_fn_transformer)
        step = hp.loaded_epoch * len(dataloader)
    else:
        start_epoch = 0
        step = 1

    for epoch in range(start_epoch, hp.max_epoch):
        dataloader = DataLoader(dataset_train, batch_sampler=sampler, num_workers=4, collate_fn=collate_fn_transformer)

        #pbar = tqdm(dataloader)
        #for d in pbar:
        for d in dataloader: 
            if hp.optimizer.lower() != 'radam':
                lr = get_learning_rate(step, hp.d_model_decoder, hp.warmup_factor, hp.warmup_step)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

                text, mel, pos_text, pos_mel, text_lengths, mel_lengths, stop_token, spk_emb, f0, energy, alignment = d

                text = text.to(DEVICE, non_blocking=True)
                mel = mel.to(DEVICE, non_blocking=True)
                pos_text = pos_text.to(DEVICE, non_blocking=True)
                pos_mel = pos_mel.to(DEVICE, non_blocking=True)
                mel_lengths = mel_lengths.to(DEVICE, non_blocking=True)
                text_lengths = text_lengths.to(DEVICE, non_blocking=True)
                stop_token = stop_token.to(DEVICE, non_blocking=True)
                if hp.is_multi_speaker:
                    spk_emb = spk_emb.to(DEVICE, non_blocking=True)
                if hp.pitch_pred:
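Example #7 uses a step-based get_learning_rate(step, d_model, warmup_factor, warmup_step) instead of reading the optimizer, and writes the result into every param_group. Its body is not shown in the snippet; the sketch below assumes the standard Transformer (Noam) warmup/decay schedule, which matches that signature.

def get_learning_rate(step, d_model, warmup_factor, warmup_step):
    # Assumed Noam-style schedule: lr grows linearly for warmup_step steps,
    # then decays with the inverse square root of the step; d_model sets the
    # overall scale.
    return warmup_factor * (d_model ** -0.5) * min(step ** -0.5,
                                                   step * warmup_step ** -1.5)


# Usage sketch: the lr peaks around warmup_step and decays afterwards.
for s in (1, 1000, 4000, 8000, 20000):
    print(s, get_learning_rate(s, d_model=256, warmup_factor=1.0,
                               warmup_step=4000))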