Example #1
    def resume_from_checkpoint(self):
        if isfile(self.resume['p2v2c']):
            print("=> loading p2v2c checkpoint '{}'".format(self.resume['p2v2c']))
            if self.is_cuda:
                checkpoint_p2v2c = torch.load(self.resume['p2v2c'])
                self.pix2vox.load_state_dict(checkpoint_p2v2c['state_dict_p2v'])
                self.vox2coord.load_state_dict(checkpoint_p2v2c['state_dict_v2c'])
            else:
                # load onto the CPU
                checkpoint_p2v2c = torch.load(self.resume['p2v2c'], map_location=lambda storage, loc: storage)
                # strip the 'module.' prefix that nn.DataParallel adds to state-dict keys
                checkpoint_p2v = {k[7:]: v for k, v in checkpoint_p2v2c['state_dict_p2v'].items()}
                checkpoint_v2c = {k[7:]: v for k, v in checkpoint_p2v2c['state_dict_v2c'].items()}
                self.pix2vox.load_state_dict(checkpoint_p2v)
                self.vox2coord.load_state_dict(checkpoint_v2c)
            print("=> loaded checkpoint '{}'".format(self.resume['p2v2c']))
        else:
            # load pixel2voxel model
            if isfile(self.resume['p2v']):
                print("=> loading p2v checkpoint '{}'".format(self.resume['p2v']))
                checkpoint_p2v = torch.load(self.resume['p2v'])
                self.pix2vox.load_state_dict(checkpoint_p2v['state_dict_p2v'])
                print("=> loaded checkpoint '{}'".format(self.resume['p2v']))

            # load vox2coord model
            if isfile(self.resume['v2c']):
                print("=> loading checkpoint '{}'".format(self.resume['v2c']))
                checkpoint_v2c = torch.load(self.resume['v2c'])
                self.vox2coord.load_state_dict(checkpoint_v2c['state_dict_v2c'])
                print("=> loaded checkpoint '{}'".format(self.resume['v2c']))
Example #2
def trans_anno(ori_file, target_file, is_val):
    file_exist = False
    no_ori = False
    train_anno = os.path.join(anno_root, target_file)
    if isfile(train_anno):
        file_exist = True
    ori_anno = os.path.join(anno_root, ori_file)
    if not isfile(ori_anno):
        no_ori = True
    if not file_exist and not no_ori:
        coco_kps = COCO(ori_anno)
        coco_ids = coco_kps.getImgIds()
        catIds = coco_kps.getCatIds(catNms=['person'])
        train_data = []
        print('transforming annotations...')
        for img_id in tqdm(coco_ids):
            img = coco_kps.loadImgs(img_id)[0]
            annIds = coco_kps.getAnnIds(imgIds=img['id'], catIds=catIds)
            anns = coco_kps.loadAnns(annIds)
            for ann in anns:
                if ann['num_keypoints'] == 0:
                    continue
                single_data = {}
                keypoints = ann['keypoints']
                bbox = ann['bbox']
                num_keypoints = ann['num_keypoints']
                file_name = img['file_name']
                unit = {}
                unit['num_keypoints'] = num_keypoints
                unit['keypoints'] = keypoints
                x1, y1, width, height = bbox
                x2 = x1 + width
                y2 = y1 + height
                unit['GT_bbox'] = [int(x1), int(y1), int(x2), int(y2)]
                single_data['unit'] = unit
                imgInfo = {}
                imgInfo['imgID'] = img_id
                imgInfo['img_paths'] = file_name
                single_data['imgInfo'] = imgInfo
                if not is_val:
                    for i in range(4):
                        tmp = single_data.copy()
                        tmp['operation'] = i
                        train_data.append(tmp)
                else:
                    single_data['score'] = 1
                    train_data.append(single_data)
        print('saving transformed annotation...')
        with open(train_anno, 'w') as wf:
            json.dump(train_data, wf)
        print('done')
    if no_ori:
        print('WARNING! No annotation file was found at {}. '
              'Make sure you have put the annotation files into the right folder.'.format(ori_anno))
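A minimal usage sketch for trans_anno (the file names are hypothetical placeholders for COCO keypoint annotation files; anno_root is assumed to be defined at module level, as in the function above):

# Hypothetical call sites; adjust the file names to your dataset layout.
trans_anno('person_keypoints_train2017.json', 'train_keypoints.json', is_val=False)
trans_anno('person_keypoints_val2017.json', 'val_keypoints.json', is_val=True)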
Example #3
def anno_transform(original, target, is_val):
    target_path = os.path.join(annot_root, target)
    original_path = os.path.join(annot_root, original)

    if isfile(target_path):
        print('The file already exists')
        return

    if not isfile(original_path):
        print('No original files')
        return

    print('The annotation is being transformed')

    ori_anno = COCO(original_path)
    img_ids = ori_anno.getImgIds()
    cat_ids = ori_anno.getCatIds(catNms=['person'])
    train_data = []
    for img_id in tqdm(img_ids):
        img = ori_anno.loadImgs(img_id)[0]
        anno_id = ori_anno.getAnnIds(imgIds=img['id'], catIds=cat_ids)
        anno_files = ori_anno.loadAnns(anno_id)

        for anno in anno_files:
            if anno['num_keypoints'] == 0:
                continue
            datum = {}
            inp = {}
            inp['num_keypoints'] = anno['num_keypoints']
            inp['keypoints'] = anno['keypoints']
            x, y, w, h = anno['bbox']
            inp['bbox'] = [int(x), int(y), int(x + w), int(y + h)]

            img_info = {}
            img_info['img_id'] = img_id
            img_info['img_path'] = img['file_name']

            datum['img_info'] = img_info
            datum['input'] = inp
            datum['area'] = anno['area']

            if not is_val:
                for i in range(4):
                    temp = datum.copy()
                    temp['operation'] = i
                    train_data.append(temp)
            else:
                datum['score'] = 1
                train_data.append(datum)

    with open(target_path, 'w') as f:
        json.dump(train_data, f)
    print('done')
Example #4
def main(args):
    os.environ['CUDA_VISIBLE_DEVICES'] = '3'
    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    model = network.__dict__[cfg.model](cfg.output_shape, cfg.num_class, pretrained=False)
    model = torch.nn.DataParallel(model).cuda()

    # define loss function (criterion) and optimizer
    criterion1 = torch.nn.MSELoss().cuda() # for Global loss
    criterion2 = torch.nn.MSELoss(reduction='none').cuda()  # for refine loss ('reduce=False' is deprecated)
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=cfg.lr,
                                 weight_decay=cfg.weight_decay)
    
    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            pretrained_dict = checkpoint['state_dict']
            model.load_state_dict(pretrained_dict)
            args.start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'), resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:        
        logger = Logger(join(args.checkpoint, 'log.txt'))
        logger.set_names(['Epoch', 'LR', 'Train Loss'])

    cudnn.benchmark = True
    print('    Total params: %.2fMB' % (sum(p.numel() for p in model.parameters())/(1024*1024)*4))

    train_loader = torch.utils.data.DataLoader(
        MscocoMulti(cfg),
        batch_size=cfg.batch_size*args.num_gpus, shuffle=True,
        num_workers=args.workers, pin_memory=True) 

    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, cfg.lr_dec_epoch, cfg.lr_gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr)) 

        # train for one epoch
        train_loss = train(train_loader, model, [criterion1, criterion2], optimizer)
    print('train_loss: ', train_loss)

        # append logger file
        logger.append([epoch + 1, lr, train_loss])

        save_model({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer' : optimizer.state_dict(),
        }, checkpoint=args.checkpoint)

    logger.close()
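adjust_learning_rate is not defined in these snippets; a step-decay sketch consistent with the call adjust_learning_rate(optimizer, epoch, cfg.lr_dec_epoch, cfg.lr_gamma) could look like the following (an assumption, not the original implementation, taking cfg.lr_dec_epoch to be a list of milestone epochs and cfg to be the module-level config used throughout):

def adjust_learning_rate(optimizer, epoch, lr_dec_epoch, lr_gamma):
    # Hypothetical step decay: scale the base LR by lr_gamma once for every
    # milestone in lr_dec_epoch that the current epoch has already reached.
    lr = cfg.lr * (lr_gamma ** sum(epoch >= milestone for milestone in lr_dec_epoch))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr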
Example #5
def anno_transform(gt, det, target):
    target_path = os.path.join(annot_root, target)
    det_path = os.path.join(annot_root, det)
    gt_path = os.path.join(annot_root, gt)

    if isfile(target_path):
        print('The file already exists')
        return

    if not isfile(det_path):
        print('No detection file found')
        return

    print('The annotation is being transformed')

    eval_gt = COCO(gt_path)
    with open(det_path) as f:
        dets = json.load(f)

    dets = [i for i in dets if i['image_id'] in eval_gt.imgs]
    dets = [i for i in dets if i['category_id'] == 1]
    dets.sort(key=lambda x: (x['image_id'], x['score']), reverse=True)

    det_anno = []

    for anno in tqdm(dets):
        datum = {}
        inp = {}
        img = eval_gt.loadImgs(anno['image_id'])[0]

        x, y, w, h = anno['bbox']
        inp['bbox'] = [int(x), int(y), int(x + w), int(y + h)]

        img_info = {}
        img_info['img_id'] = img['id']
        img_info['img_path'] = img['file_name']

        datum['img_info'] = img_info
        datum['input'] = inp
        datum['score'] = anno['score']
        det_anno.append(datum)

    with open(target_path, 'w') as f:
        json.dump(det_anno, f)

    print('done')
Example #6
def get_config(filepath=''):
    assert isfile(filepath), 'config file not found: {}'.format(filepath)
    
    with open(filepath) as f:
        data = yaml.load(f, Loader=yaml.FullLoader)
        cfg = EasyDict(data)
    return cfg
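A minimal usage sketch for get_config (the path and keys are hypothetical; EasyDict exposes the parsed YAML entries as attributes):

# Hypothetical usage; 'configs/train.yaml' and its keys are placeholders.
cfg = get_config('configs/train.yaml')
print(cfg.lr, cfg.batch_size)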
    
Example #7
def main():
    args = parse_args()
    update_config(cfg_hrnet, args)

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    #print('networks.'+ cfg_hrnet.MODEL.NAME+'.get_pose_net')
    model = eval('models.' + cfg_hrnet.MODEL.NAME + '.get_pose_net')(
        cfg_hrnet, is_train=True)
    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    # show net
    args.channels = 3
    args.height = cfg.data_shape[0]
    args.width = cfg.data_shape[1]
    #net_vision(model, args)

    # define loss function (criterion) and optimizer
    criterion = torch.nn.MSELoss(reduction='mean').cuda()

    #torch.optim.Adam
    optimizer = AdaBound(model.parameters(),
                         lr=cfg.lr,
                         weight_decay=cfg.weight_decay)

    if args.resume:
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            pretrained_dict = checkpoint['state_dict']
            model.load_state_dict(pretrained_dict)
            args.start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'), resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        logger = Logger(join(args.checkpoint, 'log.txt'))
        logger.set_names(['Epoch', 'LR', 'Train Loss'])

    cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    print('    Total params: %.2fMB' %
          (sum(p.numel() for p in model.parameters()) / (1024 * 1024) * 4))

    train_loader = torch.utils.data.DataLoader(
        #MscocoMulti(cfg),
        KPloader(cfg),
        batch_size=cfg.batch_size * len(args.gpus))
    #, shuffle=True,
    #num_workers=args.workers, pin_memory=True)

    #for i, (img, targets, valid) in enumerate(train_loader):
    #    print(i, img, targets, valid)

    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, cfg.lr_dec_epoch,
                                  cfg.lr_gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # train for one epoch
        train_loss = train(train_loader, model, criterion, optimizer)
        print('train_loss: ', train_loss)

        # append logger file
        logger.append([epoch + 1, lr, train_loss])

        save_model(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            },
            checkpoint=args.checkpoint)

    logger.close()
Example #8
def main(args):
    # import pdb; pdb.set_trace()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # device = torch.device("cpu")
    print(device)

    writer = SummaryWriter(cfg.tensorboard_path)
    # create checkpoint dir
    counter = 0
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    model = network.__dict__[cfg.model](cfg.output_shape,
                                        cfg.num_class,
                                        pretrained=True)

    model = torch.nn.DataParallel(model).to(device)
    # model = model.to(device)

    # define loss function (criterion) and optimizer
    criterion_bce = torch.nn.BCELoss().to(device)
    criterion_abs = torch.nn.L1Loss().to(device)
    # criterion_abs = offset_loss().to(device)
    # criterion1 = torch.nn.MSELoss().to(device) # for Global loss
    # criterion2 = torch.nn.MSELoss(reduce=False).to(device) # for refine loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=cfg.lr,
                                 weight_decay=cfg.weight_decay)

    if args.resume:
        print(args.resume)
        checkpoint_file_resume = os.path.join(args.checkpoint,
                                              args.resume + '.pth.tar')
        if isfile(checkpoint_file_resume):
            print("=> loading checkpoint '{}'".format(checkpoint_file_resume))
            checkpoint = torch.load(checkpoint_file_resume)
            pretrained_dict = checkpoint['state_dict']
            model.load_state_dict(pretrained_dict)
            args.start_epoch = checkpoint['epoch']
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                checkpoint_file_resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'), resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(
                checkpoint_file_resume))
    else:
        logger = Logger(join(args.checkpoint, 'log.txt'))
        logger.set_names(['Epoch', 'LR', 'Train Loss'])

    cudnn.benchmark = True
    print('    Total params: %.2fMB' %
          (sum(p.numel() for p in model.parameters()) / (1024 * 1024) * 4))

    train_loader = torch.utils.data.DataLoader(MscocoMulti_double_only(cfg),
                                               batch_size=cfg.batch_size *
                                               args.num_gpus,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, cfg.lr_dec_epoch,
                                  cfg.lr_gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # train for one epoch
        train_loss, counter = train(train_loader, model,
                                    [criterion_abs, criterion_bce], writer,
                                    counter, optimizer, device)
        print('train_loss: ', train_loss)

        # append logger file
        logger.append([epoch + 1, lr, train_loss])

        save_model(
            {
                'epoch': epoch + 1,
                'info': cfg.info,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            },
            checkpoint=args.checkpoint)

    writer.export_scalars_to_json("./test.json")
    writer.close()

    logger.close()
Example #9
def main(args):
    """
    Main training loop for training a stacked hourglass model on MPII dataset.
    :param args: Command line arguments.
    """
    global best_acc

    # create checkpoint dir
    if not isdir(args.checkpoint_dir):
        mkdir_p(args.checkpoint_dir)

    # create model
    print("==> creating model '{}', stacks={}, blocks={}".format(
        args.arch, args.stacks, args.blocks))
    model = HourglassNet(num_stacks=args.stacks,
                         num_blocks=args.blocks,
                         num_classes=args.num_classes,
                         batch_norm_momentum=args.batch_norm_momentum,
                         use_layer_norm=args.use_layer_norm,
                         width=256,
                         height=256)
    joint_visibility_model = JointVisibilityNet(hourglass_stacks=args.stacks)

    # scale weights
    if args.scale_weight_factor != 1.0:
        model.scale_weights_(args.scale_weight_factor)

    # setup horovod and model for parallel execution
    if args.use_horovod:
        hvd.init()
        torch.cuda.set_device(hvd.local_rank())
        args.lr *= hvd.size()
        model.cuda()
    else:
        model = model.cuda()
        if args.predict_joint_visibility:
            joint_visibility_model = joint_visibility_model.cuda()

    # define loss function (criterion) and optimizer
    criterion = torch.nn.MSELoss(reduction='mean').cuda()  # 'size_average=True' is deprecated
    joint_visibility_criterion = None if not args.predict_joint_visibility else torch.nn.BCEWithLogitsLoss()
    params = [{'params': model.parameters(), 'lr': args.lr}]
    if args.predict_joint_visibility:
        params.append({
            'params': joint_visibility_model.parameters(),
            'lr': args.lr
        })
    if not args.use_amsprop:
        optimizer = torch.optim.RMSprop(params,
                                        lr=args.lr,
                                        momentum=args.momentum,
                                        weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.Adam(params,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay,
                                     amsgrad=True)
    if args.use_horovod:
        optimizer = hvd.DistributedOptimizer(
            optimizer, named_parameters=model.named_parameters())

    # Create a tensorboard writer
    writer = SummaryWriter(log_dir="%s/hourglass_mpii_%s_tb_log" %
                           (args.tb_dir, args.exp))

    # optionally resume from a checkpoint
    title = 'mpii-' + args.arch
    if args.load:
        if isfile(args.load):
            print("=> loading checkpoint '{}'".format(args.load))
            checkpoint = torch.load(args.load)

            # strip the old DataParallel 'module.' prefix (the model used to be wrapped)  # TODO: remove once no old checkpoints need this
            state_dict = {}
            for key in checkpoint['state_dict']:
                new_key = key[len("module."):] if key.startswith(
                    "module.") else key
                state_dict[new_key] = checkpoint['state_dict'][key]

            # restore state
            args.start_epoch = checkpoint['epoch']
            best_acc = checkpoint['best_acc']
            model.load_state_dict(state_dict)
            if args.predict_joint_visibility:
                joint_visibility_model.load_state_dict(
                    checkpoint['joint_visibility_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])

            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.load, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint_dir, 'log.txt'),
                            title=title,
                            resume=True)
        else:
            raise Exception("=> no checkpoint found at '{}'".format(args.load))
    else:
        logger = Logger(join(args.checkpoint_dir, 'log.txt'), title=title)
        logger.set_names(
            ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])

    cudnn.benchmark = True
    print('    Total params: %.2fM' %
          (sum(p.numel() for p in model.parameters()) / 1000000.0))

    # Data loading code
    train_dataset, train_loader, val_loader = _make_torch_data_loaders(args)

    if args.evaluate:
        print('\nEvaluation only')
        loss, acc, predictions = validate(val_loader, model, criterion,
                                          args.num_classes, args.debug,
                                          args.flip)
        save_pred(predictions, checkpoint=args.checkpoint_dir)
        return

    lr = args.lr
    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, lr, args.schedule,
                                  args.gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # decay sigma
        if args.sigma_decay > 0:
            train_loader.dataset.sigma *= args.sigma_decay
            val_loader.dataset.sigma *= args.sigma_decay

        # train for one epoch
        train_loss, train_acc, joint_visibility_loss, joint_visibility_acc = train(
            train_loader,
            model=model,
            joint_visibility_model=joint_visibility_model,
            criterion=criterion,
            num_joints=args.num_classes,
            joint_visibility_criterion=joint_visibility_criterion,
            optimizer=optimizer,
            epoch=epoch,
            writer=writer,
            lr=lr,
            debug=args.debug,
            flip=args.flip,
            remove_intermediate_supervision=args.remove_intermediate_supervision,
            tb_freq=args.tb_log_freq,
            no_grad_clipping=args.no_grad_clipping,
            grad_clip=args.grad_clip,
            use_horovod=args.use_horovod,
            predict_joint_visibility=args.predict_joint_visibility,
            predict_joint_loss_coeff=args.joint_visibility_loss_coeff)

        # evaluate on validation set
        valid_loss, valid_acc_PCK, valid_acc_PCKh, valid_acc_PCKh_per_joint, valid_joint_visibility_loss, valid_joint_visibility_acc, predictions = validate(
            val_loader, model, joint_visibility_model, criterion,
            joint_visibility_criterion, args.num_classes, args.debug,
            args.flip, args.use_horovod, args.use_train_mode_to_eval,
            args.predict_joint_visibility)

        # append logger file, and write to tensorboard summaries
        writer.add_scalars('data/epoch/losses_wrt_epochs', {
            'train_loss': train_loss,
            'test_loss': valid_loss
        }, epoch)
        writer.add_scalar('data/epoch/train_accuracy_PCK', train_acc, epoch)
        writer.add_scalar('data/epoch/test_accuracy_PCK', valid_acc_PCK, epoch)
        writer.add_scalar('data/epoch/test_accuracy_PCKh', valid_acc_PCKh,
                          epoch)
        for key in valid_acc_PCKh_per_joint:
            writer.add_scalar(
                'per_joint_data/epoch/test_accuracy_PCKh_%s' % key,
                valid_acc_PCKh_per_joint[key], epoch)
        logger.append(
            [epoch + 1, lr, train_loss, valid_loss, train_acc, valid_acc_PCK])
        if args.predict_joint_visibility:
            writer.add_scalars(
                'joint_visibility/epoch/loss', {
                    'train': joint_visibility_loss,
                    'test': valid_joint_visibility_loss
                }, epoch)
            writer.add_scalars(
                'joint_visibility/epoch/acc', {
                    'train': joint_visibility_acc,
                    'test': valid_joint_visibility_acc
                }, epoch)

        # remember best acc and save checkpoint
        model_specific_checkpoint_dir = "%s/hourglass_mpii_%s" % (
            args.checkpoint_dir, args.exp)
        if not isdir(model_specific_checkpoint_dir):
            mkdir_p(model_specific_checkpoint_dir)

        is_best = valid_acc_PCK > best_acc
        best_acc = max(valid_acc_PCK, best_acc)
        mean, stddev = train_dataset.get_mean_stddev()
        checkpoint = {
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_acc': best_acc,
            'optimizer': optimizer.state_dict(),
            'mean': mean,
            'stddev': stddev,
        }
        if args.predict_joint_visibility:
            checkpoint['joint_visibility_state_dict'] = joint_visibility_model.state_dict()
        save_checkpoint(checkpoint,
                        predictions,
                        is_best,
                        checkpoint=model_specific_checkpoint_dir)

    logger.close()
Example #10
def main():
    args = parse_args()

    # create checkpoint dir
    if not isdir(args.checkpoint):
        mkdir_p(args.checkpoint)

    # create model
    model = network.__dict__[cfg.model](cfg.channel_settings,
                                        cfg.output_shape,
                                        cfg.num_class,
                                        pretrained=True)

    # show net
    args.channels = 3
    args.height = cfg.data_shape[0]
    args.width = cfg.data_shape[1]
    #net_vision(model, args)

    if 1:  # always take the resume branch; the fresh-start branch below is disabled
        if isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            model.load_state_dict(checkpoint['state_dict'])
            args.start_epoch = checkpoint['epoch']
            lr = checkpoint['lr']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
            logger = Logger(join(args.checkpoint, 'log.txt'), resume=True)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    else:
        lr = cfg.lr
        logger = Logger(join(args.checkpoint, 'log.txt'))
        logger.set_names(['Epoch', 'LR', 'Train Loss'])

    # define loss function (criterion) and optimizer
    criterion1 = torch.nn.MSELoss().cuda()  # for Global loss
    criterion2 = torch.nn.MSELoss(reduction='none').cuda()  # for refine loss ('reduce=False' is deprecated)

    model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda()

    cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    print('    Total params: %.2fMB' %
          (sum(p.numel() for p in model.parameters()) / (1024 * 1024) * 4))

    train_loader = torch.utils.data.DataLoader(
        #MscocoMulti(cfg),
        KPloader(cfg),
        batch_size=cfg.batch_size * len(args.gpus))
    #, shuffle=True,
    #num_workers=args.workers, pin_memory=True)

    #torch.optim.Adam
    optimizer = AdaBound(model.parameters(),
                         lr=lr,
                         weight_decay=cfg.weight_decay)

    for epoch in range(args.start_epoch, args.epochs):
        lr = adjust_learning_rate(optimizer, epoch, cfg.lr_dec_epoch,
                                  cfg.lr_gamma)
        print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr))

        # train for one epoch
        train_loss = train(train_loader, model, [criterion1, criterion2],
                           optimizer)
        print('train_loss: ', train_loss)

        # append logger file
        logger.append([epoch + 1, lr, train_loss])

        #save_model({
        #    'epoch': epoch + 1,
        #    'state_dict': model.state_dict(),
        #    'optimizer' : optimizer.state_dict(),
        #}, checkpoint=args.checkpoint)

        state_dict = model.module.state_dict()
        for key in state_dict.keys():
            state_dict[key] = state_dict[key].cpu()
        ckpt_path = os.path.join(args.checkpoint,
                                 "epoch" + str(epoch + 1) + "checkpoint.ckpt")
        torch.save({
            'epoch': epoch + 1,
            'state_dict': state_dict,
            'lr': lr,
        }, ckpt_path)
        print("=> Saved model to:", ckpt_path)

    logger.close()
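Because the loop above saves model.module.state_dict(), the stored keys carry no 'module.' prefix; a hedged reload sketch (the checkpoint name and target model are placeholders):

# Hypothetical reload of an epoch checkpoint written by the loop above.
# checkpoint = torch.load(os.path.join(args.checkpoint, 'epoch1checkpoint.ckpt'),
#                         map_location='cpu')
# model.module.load_state_dict(checkpoint['state_dict'])  # or load into an unwrapped model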