Example 1
import os

import numpy as np
import torch

# Trainer, TrainDataset, load_model and save_model are provided by the
# surrounding project; their imports are omitted in this listing.

def main(opt, net, save_file1, save_file2, save_file3, save_file4,
         rate, epoch=100, lr=1e-4, batch_size=64, lam=1, mutil_scale_train=True):
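    """Train `net`, validate every epoch, and checkpoint the best models.

    The best total-loss, corner-loss and heatmap-loss weights are written to
    save_file1, save_file2 and save_file3 respectively; save_file4 keeps the
    latest non-best weights.
    """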
    torch.manual_seed(100)
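    # with multi-scale training, a new input resolution is sampled every `interval` epochs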
    interval = epoch // 10
    torch.backends.cudnn.benchmark = True
    print('Creating model...')
    start_epoch = 0
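    # resume from a previous run's last checkpoint when training this particular model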
    if save_file1 == './res5/m1_best1.pth':
        net, start_epoch = load_model(net, './res50/m1_last.pth', opt.device)
    net = net.to(opt.device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr, weight_decay=1e-3)

    trainer = Trainer(opt, net, optimizer)

    print('Setting up data...')

    def adjust_lr(optimizer, p):
        # scale the learning rate of every parameter group by the factor p
        for param_group in optimizer.param_groups:
            param_group['lr'] *= p

    print('Starting training...')

    total_best = 1e10
    hm_best = 1e10
    corner_best = 1e10
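    # number of 50-epoch LR decays that had already been applied before the resumed epoch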
    ll = start_epoch // 50

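    # single-scale training: build the train/val loaders once at a fixed 512 input size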
    if not mutil_scale_train:
        train_loader = torch.utils.data.DataLoader(
            TrainDataset(opt, 512, True),
            batch_size=batch_size,
            shuffle=True,
            pin_memory=True,
            drop_last=True,
            num_workers=4
        )
        val_loader = torch.utils.data.DataLoader(
            TrainDataset(opt, 512, False),
            batch_size=batch_size,
            shuffle=False,
            pin_memory=True,
            num_workers=4
        )
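    # when resuming, re-apply the LR decay steps that the earlier epochs already triggered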
    for _ in range(ll):
        adjust_lr(optimizer, rate)
    resolutions = np.arange(416, 577, 32)  # candidate input resolutions: 416, 448, ..., 576
    # the loop variable reuses the name of the `epoch` argument; range() evaluates it once, so this is safe
    for epoch in range(start_epoch + 1, epoch + 1):
        if mutil_scale_train:
            # resample the training resolution every `interval` epochs; also rebuild the
            # loaders on the first epoch of a (possibly resumed) run so they always exist
            if epoch % interval == 1 or epoch == start_epoch + 1:
                reso = np.random.choice(resolutions)
                train_loader = torch.utils.data.DataLoader(
                    TrainDataset(opt, reso, True),
                    batch_size=batch_size,
                    shuffle=True,
                    pin_memory=True,
                    drop_last=True,
                    num_workers=4
                )
                val_loader = torch.utils.data.DataLoader(
                    TrainDataset(opt, reso, False),
                    batch_size=batch_size,
                    shuffle=False,
                    pin_memory=True,
                    num_workers=4
                )

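        # scale the learning rate by `rate` every 50 epochs (epochs 1, 51, ...)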
        if epoch % 50 == 1:
            adjust_lr(optimizer, rate)
        log_dict_train, _, _, _ = trainer.train(lam, epoch, train_loader)
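        # validate every epoch (the modulo-1 check is always true)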
        if epoch % 1 == 0:
            with torch.no_grad():
                log_dict_val, preds, hmloss, closs = trainer.val(lam, epoch, val_loader)
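            # keep the best total-loss weights in save_file1, otherwise store the latest weights
            # in save_file4; best corner-loss and heatmap-loss weights go to save_file2 and save_file3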
            if log_dict_val['total_loss'] < total_best:
                total_best = log_dict_val['total_loss']
                save_model(save_file1, epoch, net)
            else:
                save_model(save_file4, epoch, net)

            if log_dict_val['corner_loss'] < corner_best:
                corner_best = log_dict_val['corner_loss']
                save_model(save_file2, epoch, net)
            if log_dict_val['heatmap_loss'] < hm_best:
                hm_best = log_dict_val['heatmap_loss']
                save_model(save_file3, epoch, net)

    return hm_best, corner_best
Example 2
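            # fragment of a Horovod (hvd) training loop; `opt` here is the optimizer whose LR is being dropped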
            print('Drop LR to', lr)
            for param_group in opt.param_groups:
                param_group['lr'] = lr

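        # print GPU status once per run, only from the root rank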
        if not written and hvd.rank() == 0:
            os.system("nvidia-smi")
            written = True

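        # only the root rank logs training metrics, runs validation and saves checkpoints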
        if hvd.rank() == 0:
            logger.write('epoch: {} |'.format(epoch))
            for k, v in log_dict_train.items():
                logger.write('{} {:8f} | '.format(k, v))
            logger.write('\n')

            with torch.no_grad():
                log_dict_val, preds = trainer.val(epoch, val_loader, rank=hvd.rank())
                for k, v in log_dict_val.items():
                    logger.write('{} {:8f} | '.format(k, v))
                logger.write('\n')

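            # checkpoint at the epochs listed in args.save_point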
            if epoch in args.save_point:
                save_model(os.path.join(args.weights_dir, 'model_{}.pth'.format(epoch)),
                           epoch, blaze_palm, opt)