Example #1
def main(args):
    
    # build train and val set
    train_dir = args.train_dir
    val_dir = args.val_dir

    config = Config(args.config)
    cudnn.benchmark = True

    # train
    train_loader = torch.utils.data.DataLoader(
        lsp_lspet_data.LSP_Data(
            'lspet', train_dir, 8,
            Mytransforms.Compose([
                Mytransforms.RandomResized(),
                Mytransforms.RandomRotate(40),
                Mytransforms.RandomCrop(368),
                Mytransforms.RandomHorizontalFlip(),
            ])),
        batch_size=config.batch_size, shuffle=True,
        num_workers=config.workers, pin_memory=True)
    # val (only built when a val directory and test interval are configured;
    # otherwise None is returned, avoiding a NameError at the return below)
    val_loader = None
    if args.val_dir is not None and config.test_interval != 0:
        val_loader = torch.utils.data.DataLoader(
            lsp_lspet_data.LSP_Data(
                'lsp', val_dir, 8,
                Mytransforms.Compose([Mytransforms.TestResized(368)])),
            batch_size=config.batch_size, shuffle=False,
            num_workers=config.workers, pin_memory=True)
    
    # build model
    model = MSBR(config=config, args=args, k=14, stages=config.stages)

    model.build_nets()


    return model, train_loader, val_loader
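
Note: all three examples rely on an AverageMeter helper that this page does not show. A minimal sketch, assuming the conventional running-average class from the PyTorch ImageNet example; the val/avg/sum/count attributes and the update(value, n) signature match how the examples use it:

class AverageMeter(object):
    """Tracks the latest value and a running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        # `val` is a per-batch average; `n` is the number of samples behind it
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
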
Example #2
def train_val(model, args):

    train_dir = args.train_dir
    val_dir = args.val_dir

    config = Config(args.config)
    cudnn.benchmark = True

    # train
    train_loader = torch.utils.data.DataLoader(
        lsp_lspet_data.LSP_Data(
            'lspet', train_dir, 8,
            Mytransforms.Compose([
                Mytransforms.RandomResized(),
                Mytransforms.RandomRotate(40),
                Mytransforms.RandomCrop(368),
                Mytransforms.RandomHorizontalFlip(),
            ])),
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.workers,
        pin_memory=True)
    # val
    if args.val_dir is not None and config.test_interval != 0:
        val_loader = torch.utils.data.DataLoader(
            lsp_lspet_data.LSP_Data(
                'lsp', val_dir, 8,
                Mytransforms.Compose([Mytransforms.TestResized(368)])),
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=config.workers,
            pin_memory=True)

    if args.gpu[0] < 0:
        criterion = nn.MSELoss()
    else:
        criterion = nn.MSELoss().cuda()

    params, multiple = get_parameters(model, config, True)
    # params, multiple = get_parameters(model, config, False)

    optimizer = torch.optim.SGD(params,
                                config.base_lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_list = [AverageMeter() for i in range(6)]
    end = time.time()
    iters = config.start_iters
    best_model = config.best_model

    heat_weight = 46 * 46 * 15 / 1.0
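    # 46x46 is the stage output resolution (368 crop / stride 8); the 15
    # channels are presumably the 14 LSP joints plus a background map, so the
    # per-pixel MSE is rescaled to the full heatmap volume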

    losstracker1 = []
    losstracker2 = []
    losstracker3 = []
    losstracker4 = []
    losstracker5 = []
    losstracker6 = []
    while iters < config.max_iter:

        for i, (input, heatmap, centermap) in enumerate(train_loader):

            learning_rate = adjust_learning_rate(
                optimizer,
                iters,
                config.base_lr,
                policy=config.lr_policy,
                policy_parameter=config.policy_parameter,
                multiple=multiple)
            data_time.update(time.time() - end)

            if args.gpu[0] >= 0:
                # `async` became a reserved word in Python 3.7; the keyword
                # argument is now called non_blocking
                heatmap = heatmap.cuda(non_blocking=True)
                centermap = centermap.cuda(non_blocking=True)

            input_var = torch.autograd.Variable(input)
            heatmap_var = torch.autograd.Variable(heatmap)
            centermap_var = torch.autograd.Variable(centermap)

            heat1, heat2, heat3, heat4, heat5, heat6 = model(
                input_var, centermap_var)

            loss1 = criterion(heat1, heatmap_var) * heat_weight
            loss2 = criterion(heat2, heatmap_var) * heat_weight
            loss3 = criterion(heat3, heatmap_var) * heat_weight
            loss4 = criterion(heat4, heatmap_var) * heat_weight
            loss5 = criterion(heat5, heatmap_var) * heat_weight
            loss6 = criterion(heat6, heatmap_var) * heat_weight

            loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6
            losses.update(loss.item(), input.size(0))
            for cnt, l in enumerate([loss1, loss2, loss3, loss4, loss5,
                                     loss6]):
                losses_list[cnt].update(l.item(), input.size(0))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            iters += 1
            if iters % config.display == 0:
                print(
                    'Train Iteration: {0}\t'
                    'Time {batch_time.sum:.3f}s / {1}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {1}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {2}\n'
                    'Loss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.format(
                        iters,
                        config.display,
                        learning_rate,
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=losses))
                for cnt in range(0, 6):
                    print(
                        'Loss{0} = {loss1.val:.8f} (ave = {loss1.avg:.8f})\t'.
                        format(cnt + 1, loss1=losses_list[cnt]))

                print(
                    time.strftime(
                        '%Y-%m-%d %H:%M:%S -----------------------------------------------------------------------------------------------------------------\n',
                        time.localtime()))

                batch_time.reset()
                data_time.reset()
                losses.reset()
                for cnt in range(6):
                    losses_list[cnt].reset()

            save_checkpoint({
                'iter': iters,
                'state_dict': model.state_dict(),
            }, 0, args.model_name)
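            # the 0 above is presumably save_checkpoint's is_best flag: a
            # periodic snapshot rather than a best-model copy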

            # val
            if args.val_dir is not None and config.test_interval != 0 and iters % config.test_interval == 0:

                model.eval()
                for j, (input, heatmap, centermap) in enumerate(val_loader):
                    if args.gpu[0] >= 0:
                        heatmap = heatmap.cuda(non_blocking=True)
                        centermap = centermap.cuda(non_blocking=True)

                    input_var = torch.autograd.Variable(input)
                    heatmap_var = torch.autograd.Variable(heatmap)
                    centermap_var = torch.autograd.Variable(centermap)

                    heat1, heat2, heat3, heat4, heat5, heat6 = model(
                        input_var, centermap_var)

                    loss1 = criterion(heat1, heatmap_var) * heat_weight
                    loss2 = criterion(heat2, heatmap_var) * heat_weight
                    loss3 = criterion(heat3, heatmap_var) * heat_weight
                    loss4 = criterion(heat4, heatmap_var) * heat_weight
                    loss5 = criterion(heat5, heatmap_var) * heat_weight
                    loss6 = criterion(heat6, heatmap_var) * heat_weight

                    loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6
                    losses.update(loss.item(), input.size(0))
                    for cnt, l in enumerate(
                        [loss1, loss2, loss3, loss4, loss5, loss6]):
                        losses_list[cnt].update(l.item(), input.size(0))

                    batch_time.update(time.time() - end)
                    end = time.time()
                    is_best = losses.avg < best_model
                    best_model = min(best_model, losses.avg)
                    save_checkpoint(
                        {
                            'iter': iters,
                            'state_dict': model.state_dict(),
                        }, is_best, args.model_name)

                    if j % config.display == 0:
                        print(
                            'Test Iteration: {0}\t'
                            'Time {batch_time.sum:.3f}s / {1}iters, ({batch_time.avg:.3f})\t'
                            'Data load {data_time.sum:.3f}s / {1}iters, ({data_time.avg:.3f})\n'
                            'Loss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.
                            format(j,
                                   config.display,
                                   batch_time=batch_time,
                                   data_time=data_time,
                                   loss=losses))
                        for cnt in range(0, 6):
                            print(
                                'Loss{0} = {loss1.val:.8f} (ave = {loss1.avg:.8f})\t'
                                .format(cnt + 1, loss1=losses_list[cnt]))

                        print(
                            time.strftime(
                                '%Y-%m-%d %H:%M:%S -----------------------------------------------------------------------------------------------------------------\n',
                                time.localtime()))
                        batch_time.reset()
                        losses.reset()
                        for cnt in range(6):
                            losses_list[cnt].reset()

                        # store plain floats so the autograd graph is freed
                        # and np.asarray below receives numeric data
                        losstracker1.append(loss1.item())
                        losstracker2.append(loss2.item())
                        losstracker3.append(loss3.item())
                        losstracker4.append(loss4.item())
                        losstracker5.append(loss5.item())
                        losstracker6.append(loss6.item())
                model.train()

    np.save('loss1', np.asarray(losstracker1))
    np.save('loss2', np.asarray(losstracker2))
    np.save('loss3', np.asarray(losstracker3))
    np.save('loss4', np.asarray(losstracker4))
    np.save('loss5', np.asarray(losstracker5))
    np.save('loss6', np.asarray(losstracker6))
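
Note: adjust_learning_rate and save_checkpoint are also not shown on this page. A minimal sketch of both under assumed interfaces; the 'step' policy with gamma/step_size keys in policy_parameter is one plausible reading, and the per-group `multiple` scaling matches how get_parameters is used above:

import shutil
import torch


def adjust_learning_rate(optimizer, iters, base_lr,
                         policy='step', policy_parameter=None, multiple=None):
    # one plausible 'step' policy: lr = base_lr * gamma ** (iters // step_size)
    if policy == 'step':
        lr = base_lr * policy_parameter['gamma'] ** (
            iters // policy_parameter['step_size'])
    else:
        lr = base_lr  # other policies omitted in this sketch
    for i, param_group in enumerate(optimizer.param_groups):
        # `multiple` holds a per-group factor, e.g. a larger rate for new layers
        param_group['lr'] = lr * (multiple[i] if multiple else 1.0)
    return lr


def save_checkpoint(state, is_best, model_name):
    filename = '{}_latest.pth.tar'.format(model_name)
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, '{}_best.pth.tar'.format(model_name))
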
Example #3
def train_val(model, args):

    train_dir = args.train_dir
    val_dir = args.val_dir

    config = Config(args.config)
    cudnn.benchmark = True

    # The LSPET dataset contains 10,000 images; the LSP dataset contains 2,000.

    # train
    train_loader = torch.utils.data.DataLoader(
        lsp_lspet_data.LSP_Data(
            'lspet', train_dir, 8,
            Mytransforms.Compose([
                Mytransforms.RandomResized(),
                Mytransforms.RandomRotate(40),
                Mytransforms.RandomCrop(368),
                Mytransforms.RandomHorizontalFlip(),
            ])),
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.workers,
        pin_memory=True)

    # val
    if args.val_dir is not None and config.test_interval != 0:
        val_loader = torch.utils.data.DataLoader(
            lsp_lspet_data.LSP_Data(
                'lsp', val_dir, 8,
                Mytransforms.Compose([Mytransforms.TestResized(368)])),
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=config.workers,
            pin_memory=True)

    criterion = nn.MSELoss().cuda()

    params, multiple = get_parameters(model, config, False)

    optimizer = torch.optim.SGD(params,
                                config.base_lr,
                                momentum=config.momentum,
                                weight_decay=config.weight_decay)

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    losses_list = [AverageMeter() for i in range(6)]
    end = time.time()
    iters = config.start_iters
    best_model = config.best_model

    heat_weight = 46 * 46 * 15 / 1.0

    while iters < config.max_iter:
        # each pass through train_loader advances i by 1 and yields one batch
        # (16 images here) of (input, heatmap, centermap, img_path)
        for i, (input, heatmap, centermap,
                img_path) in enumerate(train_loader):

            learning_rate = adjust_learning_rate(
                optimizer,
                iters,
                config.base_lr,
                policy=config.lr_policy,
                policy_parameter=config.policy_parameter,
                multiple=multiple)
            data_time.update(time.time() - end)

            heatmap = heatmap.cuda(non_blocking=True)
            centermap = centermap.cuda(non_blocking=True)

            input_var = torch.autograd.Variable(input)
            heatmap_var = torch.autograd.Variable(heatmap)
            centermap_var = torch.autograd.Variable(centermap)

            heat1, heat2, heat3, heat4, heat5, heat6 = model(
                input_var, centermap_var)

            loss1 = criterion(heat1, heatmap_var) * heat_weight
            loss2 = criterion(heat2, heatmap_var) * heat_weight
            loss3 = criterion(heat3, heatmap_var) * heat_weight
            loss4 = criterion(heat4, heatmap_var) * heat_weight
            loss5 = criterion(heat5, heatmap_var) * heat_weight
            loss6 = criterion(heat6, heatmap_var) * heat_weight

            loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6
            losses.update(loss.item(), input.size(0))
            for cnt, l in enumerate([loss1, loss2, loss3, loss4, loss5,
                                     loss6]):
                losses_list[cnt].update(l.item(), input.size(0))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            iters += 1
            if iters % config.display == 0:
                print(
                    'Train Iteration: {0}\t'
                    'Time {batch_time.sum:.3f}s / {1}iters, ({batch_time.avg:.3f})\t'
                    'Data load {data_time.sum:.3f}s / {1}iters, ({data_time.avg:.3f})\n'
                    'Learning rate = {2}\n'
                    'Loss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.format(
                        iters,
                        config.display,
                        learning_rate,
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=losses))
                for cnt in range(0, 6):
                    print(
                        'Loss{0} = {loss1.val:.8f} (ave = {loss1.avg:.8f})\t'.
                        format(cnt + 1, loss1=losses_list[cnt]))

                print(
                    time.strftime(
                        '%Y-%m-%d %H:%M:%S -----------------------------------------------------------------------------------------------------------------\n',
                        time.localtime()))
                #############    image write  ##################
                for cnt in range(input.size(0)):
                    # iterate over the actual batch: the last batch can be
                    # smaller than config.batch_size
                    kpts = get_kpts(heat6[cnt], img_h=368.0, img_w=368.0)
                    draw_paint(img_path[cnt], kpts, i, cnt)
                #######################################################
                batch_time.reset()
                data_time.reset()
                losses.reset()
                for cnt in range(6):
                    losses_list[cnt].reset()

            save_checkpoint({
                'iter': iters,
                'state_dict': model.state_dict(),
            }, 0, args.model_name)

            # val
            if args.val_dir is not None and config.test_interval != 0 and iters % config.test_interval == 0:

                model.eval()
                for j, (input, heatmap, centermap) in enumerate(val_loader):
                    # note: unlike the train loop above, no img_path is
                    # unpacked here; this assumes the 'lsp' split of LSP_Data
                    # yields three elements
                    heatmap = heatmap.cuda(non_blocking=True)
                    centermap = centermap.cuda(non_blocking=True)

                    input_var = torch.autograd.Variable(input)
                    heatmap_var = torch.autograd.Variable(heatmap)
                    centermap_var = torch.autograd.Variable(centermap)

                    heat1, heat2, heat3, heat4, heat5, heat6 = model(
                        input_var, centermap_var)

                    loss1 = criterion(heat1, heatmap_var) * heat_weight
                    loss2 = criterion(heat2, heatmap_var) * heat_weight
                    loss3 = criterion(heat3, heatmap_var) * heat_weight
                    loss4 = criterion(heat4, heatmap_var) * heat_weight
                    loss5 = criterion(heat5, heatmap_var) * heat_weight
                    loss6 = criterion(heat6, heatmap_var) * heat_weight

                    loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6
                    losses.update(loss.item(), input.size(0))
                    for cnt, l in enumerate(
                        [loss1, loss2, loss3, loss4, loss5, loss6]):
                        losses_list[cnt].update(l.item(), input.size(0))

                    batch_time.update(time.time() - end)
                    end = time.time()
                    is_best = losses.avg < best_model
                    best_model = min(best_model, losses.avg)
                    save_checkpoint(
                        {
                            'iter': iters,
                            'state_dict': model.state_dict(),
                        }, is_best, args.model_name)

                    if j % config.display == 0:
                        print(
                            'Test Iteration: {0}\t'
                            'Time {batch_time.sum:.3f}s / {1}iters, ({batch_time.avg:.3f})\t'
                            'Data load {data_time.sum:.3f}s / {1}iters, ({data_time.avg:.3f})\n'
                            'Loss = {loss.val:.8f} (ave = {loss.avg:.8f})\n'.
                            format(j,
                                   config.display,
                                   batch_time=batch_time,
                                   data_time=data_time,
                                   loss=losses))
                        for cnt in range(0, 6):
                            print(
                                'Loss{0} = {loss1.val:.8f} (ave = {loss1.avg:.8f})\t'
                                .format(cnt + 1, loss1=losses_list[cnt]))

                        print(
                            time.strftime(
                                '%Y-%m-%d %H:%M:%S -----------------------------------------------------------------------------------------------------------------\n',
                                time.localtime()))
                        batch_time.reset()
                        losses.reset()
                        for cnt in range(6):
                            losses_list[cnt].reset()

                model.train()
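
For context, a hedged sketch of how these functions might be wired together from the command line. The attribute names (train_dir, val_dir, config, gpu, model_name) mirror what the examples read from args; the flag spellings and defaults are assumptions:

import argparse


def parse_args():
    parser = argparse.ArgumentParser(description='CPM training on LSP/LSPET')
    parser.add_argument('--train_dir', required=True)
    parser.add_argument('--val_dir', default=None)
    parser.add_argument('--config', required=True)
    parser.add_argument('--gpu', type=int, nargs='+', default=[0])
    parser.add_argument('--model_name', default='cpm')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    # Example #1 builds the model (train_val rebuilds the loaders itself)
    model, _, _ = main(args)
    train_val(model, args)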