def main():
    """Train (or only evaluate) a CU-Net pose model on MPII.

    All settings come from ``TrainOptions().parse()``. The network is built
    with ``create_cu_net``, wrapped in ``DataParallel`` and moved to the GPU(s)
    named by ``opt.gpu_id``.

    If ``opt.resume_prefix`` is non-empty, model/optimizer/history are
    restored from the checkpoint, and ``opt.lr`` is taken from the restored
    optimizer.

    If ``opt.is_train`` is false, the model is validated once and the
    predictions are saved. Otherwise, each epoch from the start epoch up to
    ``opt.nEpochs`` does the following:

    * adjust the LR, train, validate;
    * update the history and save a checkpoint;
    * append a row to ``training-summary.txt``.

    Side effects: sets the module-level ``quan_op`` and the
    ``CUDA_VISIBLE_DEVICES`` environment variable.
    """
    global quan_op
    opt = TrainOptions().parse()
    train_history = TrainHistory()
    checkpoint = Checkpoint()
    visualizer = Visualizer(opt)
    exp_dir = os.path.join(opt.exp_dir, opt.exp_id)
    log_name = opt.vis_env + 'log.txt'
    # ``log_path`` is the attribute the eval-only branch below also sets.
    visualizer.log_path = os.path.join(exp_dir, log_name)
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id

    # Only MPII is wired up here (see the loaders below): 16 joints.
    num_classes = 16
    net = create_cu_net(neck_size=4,
                        growth_rate=32,
                        init_chan_num=128,
                        num_classes=num_classes,
                        layer_num=opt.layer_num,
                        max_link=1,
                        inter_loss_num=opt.layer_num)
    net = torch.nn.DataParallel(net).cuda()
    # Module-level handle; NOTE(review): presumably read by train()/validate().
    quan_op = QuanOp(net)
    optimizer = torch.optim.RMSprop(net.parameters(),
                                    lr=opt.lr,
                                    alpha=0.99,
                                    eps=1e-8,
                                    momentum=0,
                                    weight_decay=0)

    # Optionally resume from a checkpoint; new checkpoints always go to exp_dir.
    checkpoint.save_prefix = exp_dir + '/'
    resuming = opt.resume_prefix != ''
    if resuming:
        # [0:-1] drops the prefix's last character (presumably a trailing
        # separator -- confirm against how checkpoints are named).
        checkpoint.load_prefix = os.path.join(exp_dir, opt.resume_prefix)[0:-1]
        checkpoint.load_checkpoint(net, optimizer, train_history)
        # Continue with the learning rate stored in the checkpoint.
        opt.lr = optimizer.param_groups[0]['lr']
    print('save prefix:  %s' % checkpoint.save_prefix)

    # Data loaders.
    train_loader = torch.utils.data.DataLoader(MPII(
        'dataset/mpii-hr-lsp-normalizer.json',
        '/bigdata1/zt53/data',
        is_train=True),
                                               batch_size=opt.bs,
                                               shuffle=True,
                                               num_workers=opt.nThreads,
                                               pin_memory=True)
    val_loader = torch.utils.data.DataLoader(MPII(
        'dataset/mpii-hr-lsp-normalizer.json',
        '/bigdata1/zt53/data',
        is_train=False),
                                             batch_size=opt.bs,
                                             shuffle=False,
                                             num_workers=opt.nThreads,
                                             pin_memory=True)
    print(type(optimizer))

    # All 16 joint indices except 6, 7, 8, 9, 12, 13.
    idx = [0, 1, 2, 3, 4, 5, 10, 11, 14, 15]

    if not opt.is_train:
        # Evaluation only: validate the loaded model and save predictions.
        visualizer.log_path = os.path.join(exp_dir, 'val_log.txt')
        _, _, predictions = validate(val_loader, net,
                                     train_history.epoch[-1]['epoch'],
                                     visualizer, idx, joint_flip_index,
                                     num_classes)
        checkpoint.save_preds(predictions)
        return

    # Created after the eval-only return so that path neither opens nor
    # writes to the training summary file.
    logger = Logger(os.path.join(exp_dir, 'training-summary.txt'),
                    title='training-summary',
                    resume=resuming)
    logger.set_names(
        ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])

    start_epoch = train_history.epoch[-1]['epoch'] + 1 if resuming else 0
    try:
        for epoch in range(start_epoch, opt.nEpochs):
            adjust_lr(opt, optimizer, epoch)
            train_loss, train_pckh = train(train_loader, net, optimizer,
                                           epoch, visualizer, idx, opt)
            val_loss, val_pckh, predictions = validate(val_loader, net, epoch,
                                                       visualizer, idx,
                                                       joint_flip_index,
                                                       num_classes)
            # Record this epoch in the history, checkpoint, and summary log.
            e = OrderedDict([('epoch', epoch)])
            lr = OrderedDict([('lr', optimizer.param_groups[0]['lr'])])
            loss = OrderedDict([('train_loss', train_loss),
                                ('val_loss', val_loss)])
            pckh = OrderedDict([('val_pckh', val_pckh)])
            train_history.update(e, lr, loss, pckh)
            checkpoint.save_checkpoint(net, optimizer, train_history,
                                       predictions)
            logger.append([
                epoch, optimizer.param_groups[0]['lr'], train_loss, val_loss,
                train_pckh, val_pckh
            ])
    finally:
        # Close the summary file even if training raises.
        logger.close()
# Example no. 2
# 0
def main():
    """Train (or only evaluate) a two-stack hourglass pose network on MPII.

    Settings come from ``TrainOptions().parse()``. A non-empty
    ``opt.load_prefix_pose`` restores model, optimizer and history from
    ``<exp_dir>/<exp_id>/<load_prefix_pose minus its last char>``, and new
    checkpoints are saved under the same prefix.

    With ``opt.is_train`` false, the model is validated once and the
    predictions are saved. Otherwise, each epoch does the following:

    * adjust the LR, train, validate;
    * update the history and save a checkpoint;
    * plot the history.
    """
    opt = TrainOptions().parse()
    history = PoseTrainHistory()
    ckpt = Checkpoint()
    vis = Visualizer(opt)
    experiment_dir = os.path.join(opt.exp_dir, opt.exp_id)
    vis.log_path = os.path.join(experiment_dir, opt.vis_env + 'log.txt')
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id

    # MPII only: 16 output classes.
    n_classes = 16
    model = create_hg(num_stacks=2,
                      num_modules=1,
                      num_classes=n_classes,
                      chan=256)
    model = torch.nn.DataParallel(model).cuda()

    # RMSprop without momentum or weight decay.
    optim = torch.optim.RMSprop(model.parameters(),
                                lr=opt.lr,
                                alpha=0.99,
                                eps=1e-8,
                                momentum=0,
                                weight_decay=0)

    # Optionally resume; the load prefix is the save prefix minus its last char.
    resuming = opt.load_prefix_pose != ''
    if not resuming:
        ckpt.save_prefix = experiment_dir + '/'
    else:
        joined = os.path.join(experiment_dir, opt.load_prefix_pose)
        ckpt.save_prefix = joined
        ckpt.load_prefix = joined[0:-1]
        ckpt.load_checkpoint(model, optim, history)
    print('save prefix: ', ckpt.save_prefix)

    def mpii_loader(training):
        # Training data is shuffled; validation data keeps its order.
        mpii = MPII('dataset/mpii-hr-lsp-normalizer.json',
                    '/bigdata1/zt53/data',
                    is_train=training)
        return torch.utils.data.DataLoader(mpii,
                                           batch_size=opt.bs,
                                           shuffle=training,
                                           num_workers=opt.nThreads,
                                           pin_memory=True)

    train_loader = mpii_loader(True)
    val_loader = mpii_loader(False)

    print(type(optim), optim.param_groups[0]['lr'])

    # All 16 joint indices except 6, 7, 8, 9, 12, 13.
    joint_subset = [0, 1, 2, 3, 4, 5, 10, 11, 14, 15]

    if not opt.is_train:
        # Evaluation only: validate the loaded model and save predictions.
        vis.log_path = os.path.join(opt.exp_dir, opt.exp_id, 'val_log.txt')
        last_epoch = history.epoch[-1]['epoch']
        _, _, predictions = validate(val_loader, model, last_epoch, vis,
                                     joint_subset, n_classes)
        ckpt.save_preds(predictions)
        return

    first_epoch = history.epoch[-1]['epoch'] + 1 if resuming else 0
    for epoch in range(first_epoch, opt.nEpochs):
        adjust_lr(opt, optim, epoch)
        train_loss, train_pckh = train(train_loader, model, optim, epoch,
                                       vis, joint_subset, opt)
        val_loss, val_pckh, predictions = validate(val_loader, model, epoch,
                                                   vis, joint_subset,
                                                   n_classes)
        history.update(
            OrderedDict([('epoch', epoch)]),
            OrderedDict([('lr', optim.param_groups[0]['lr'])]),
            OrderedDict([('train_loss', train_loss), ('val_loss', val_loss)]),
            OrderedDict([('train_pckh', train_pckh), ('val_pckh', val_pckh)]))
        ckpt.save_checkpoint(model, optim, history, predictions)
        vis.plot_train_history(history)
def main():
    """Jointly train the hourglass pose network and the scale/rotation agent.

    Required options: ``opt.joint_dir``, ``opt.load_prefix_pose`` and
    ``opt.load_prefix_sr``. Only ``opt.dataset == 'mpii'`` is supported.
    Outputs go to ``<exp_dir>/<exp_id>/<joint_dir>-<pose prefix minus last char>``.

    Where checkpoints are loaded from:

    * ``opt.load_checkpoint`` true: resume both models from that joint
      directory.
    * ``opt.load_checkpoint`` false: start from the pretrained pose and agent
      checkpoints, with learning rates reset to ``opt.lr`` / ``opt.agent_lr``.

    With ``opt.is_train`` false, the pose model is validated once and the
    predictions are saved. Otherwise, each epoch does the following:

    * train hg (``train_hg`` also receives the agent);
    * validate, checkpoint and log the pose model;
    * train the agent for one epoch and checkpoint it.
    """
    opt = TrainOptions().parse()
    # Validate every required option before creating directories or
    # building models on the GPU.
    if opt.joint_dir == '':
        print('joint directory is null.')
        exit()
    if opt.load_prefix_pose == '':
        print('please input the checkpoint name of the pose model')
        exit()
    if opt.load_prefix_sr == '':
        print('please input the checkpoint name of the sr agent.')
        exit()
    if opt.dataset != 'mpii':
        # Only MPII is wired up below; any other value left num_classes unbound.
        print('unsupported dataset: %s' % opt.dataset)
        exit()

    # [0:-1] drops each prefix's last character (presumably a trailing
    # separator -- confirm against checkpoint naming).
    pose_tag = opt.load_prefix_pose[0:-1]
    sr_tag = opt.load_prefix_sr[0:-1]
    exp_dir = os.path.join(opt.exp_dir, opt.exp_id)
    joint_dir = os.path.join(exp_dir, opt.joint_dir + '-' + pose_tag)
    if not os.path.isdir(joint_dir):
        os.makedirs(joint_dir)

    visualizer = Visualizer(opt)
    visualizer.log_path = joint_dir + '/' + 'train-log.txt'

    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id

    num_classes = 16

    # Pose network (hourglass) and its optimizer.
    hg = model.create_hg(num_stacks=2,
                         num_modules=1,
                         num_classes=num_classes,
                         chan=256)
    hg = torch.nn.DataParallel(hg).cuda()
    optimizer_hg = torch.optim.RMSprop(hg.parameters(),
                                       lr=opt.lr,
                                       alpha=0.99,
                                       eps=1e-8,
                                       momentum=0,
                                       weight_decay=0)
    train_history_pose = PoseTrainHistory()
    checkpoint_hg = Checkpoint()
    if opt.load_checkpoint:
        # Resume a previous joint-training run.
        checkpoint_hg.load_prefix = joint_dir + '/' + pose_tag
        checkpoint_hg.load_checkpoint(hg, optimizer_hg, train_history_pose)
    else:
        # Start from the pretrained pose model with a fresh learning rate.
        checkpoint_hg.load_prefix = exp_dir + '/' + pose_tag
        checkpoint_hg.load_checkpoint(hg, optimizer_hg, train_history_pose)
        for param_group in optimizer_hg.param_groups:
            param_group['lr'] = opt.lr
    checkpoint_hg.save_prefix = joint_dir + '/pose-'
    print('hg optimizer:  %s %s' %
          (type(optimizer_hg), optimizer_hg.param_groups[0]['lr']))

    # Scale/rotation agent and its optimizer.
    agent_sr = model.create_asn(chan_in=256,
                                chan_out=256,
                                scale_num=len(dataset.scale_means),
                                rotation_num=len(dataset.rotation_means),
                                is_aug=True)
    agent_sr = torch.nn.DataParallel(agent_sr).cuda()
    optimizer_sr = torch.optim.RMSprop(agent_sr.parameters(),
                                       lr=opt.agent_lr,
                                       alpha=0.99,
                                       eps=1e-8,
                                       momentum=0,
                                       weight_decay=0)
    train_history_sr = ASNTrainHistory()
    checkpoint_sr = Checkpoint()
    if opt.load_checkpoint:
        # Resume a previous joint-training run.
        checkpoint_sr.load_prefix = joint_dir + '/' + sr_tag
        checkpoint_sr.load_checkpoint(agent_sr, optimizer_sr, train_history_sr)
    else:
        # Start from the pretrained agent with a fresh learning rate.
        sr_pretrain_dir = os.path.join(exp_dir, opt.sr_dir + '-' + pose_tag)
        checkpoint_sr.load_prefix = sr_pretrain_dir + '/' + sr_tag
        checkpoint_sr.load_checkpoint(agent_sr, optimizer_sr, train_history_sr)
        for param_group in optimizer_sr.param_groups:
            param_group['lr'] = opt.agent_lr
    checkpoint_sr.save_prefix = joint_dir + '/agent-'
    print('agent optimizer:  %s %s' %
          (type(optimizer_sr), optimizer_sr.param_groups[0]['lr']))

    # Data loaders.
    train_dataset_hg = MPII('dataset/mpii-hr-lsp-normalizer.json',
                            '/bigdata1/zt53/data',
                            is_train=True)
    train_loader_hg = torch.utils.data.DataLoader(train_dataset_hg,
                                                  batch_size=opt.bs,
                                                  shuffle=True,
                                                  num_workers=opt.nThreads,
                                                  pin_memory=True)
    val_dataset_hg = MPII('dataset/mpii-hr-lsp-normalizer.json',
                          '/bigdata1/zt53/data',
                          is_train=False)
    val_loader_hg = torch.utils.data.DataLoader(val_dataset_hg,
                                                batch_size=opt.bs,
                                                shuffle=False,
                                                num_workers=opt.nThreads,
                                                pin_memory=True)
    train_dataset_agent = AGENT('dataset/mpii-hr-lsp-normalizer.json',
                                '/bigdata1/zt53/data',
                                separate_s_r=True)
    train_loader_agent = torch.utils.data.DataLoader(train_dataset_agent,
                                                     batch_size=opt.bs,
                                                     shuffle=True,
                                                     num_workers=opt.nThreads,
                                                     pin_memory=True)

    if not opt.is_train:
        # Evaluation only: validate the pose model and save predictions.
        visualizer.log_path = joint_dir + '/' + 'val-log.txt'
        _, _, predictions = validate(val_loader_hg, hg,
                                     train_history_pose.epoch[-1]['epoch'],
                                     visualizer, num_classes)
        checkpoint_hg.save_preds(predictions)
        return

    logger = Logger(joint_dir + '/' + 'pose-training-summary.txt',
                    title='pose-training-summary')
    logger.set_names(
        ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train PCKh', 'Val PCKh'])

    # Both histories continue from their last recorded epoch.
    start_epoch_pose = train_history_pose.epoch[-1]['epoch'] + 1
    epoch_sr = train_history_sr.epoch[-1]['epoch'] + 1
    try:
        for epoch in range(start_epoch_pose, opt.nEpochs):
            # Train hg for one epoch (the agent is passed to train_hg).
            adjust_lr(opt, optimizer_hg, epoch)
            train_loss_pose, train_pckh = train_hg(train_loader_hg, hg,
                                                   optimizer_hg, agent_sr,
                                                   epoch, visualizer, opt)
            val_loss, val_pckh, predictions = validate(val_loader_hg, hg,
                                                       epoch, visualizer,
                                                       num_classes)
            e_pose = OrderedDict([('epoch', epoch)])
            lr_pose = OrderedDict([('lr', optimizer_hg.param_groups[0]['lr'])])
            loss_pose = OrderedDict([('train_loss', train_loss_pose),
                                     ('val_loss', val_loss)])
            pckh = OrderedDict([('train_pckh', train_pckh),
                                ('val_pckh', val_pckh)])
            train_history_pose.update(e_pose, lr_pose, loss_pose, pckh)
            checkpoint_hg.save_checkpoint(hg, optimizer_hg, train_history_pose,
                                          predictions)
            visualizer.plot_train_history(train_history_pose)
            logger.append([
                epoch, optimizer_hg.param_groups[0]['lr'], train_loss_pose,
                val_loss, train_pckh, val_pckh
            ])

            # Train agent_sr for one epoch; it has no validation loss (logged as 0).
            train_loss_sr = train_agent_sr(train_loader_agent, hg, agent_sr,
                                           optimizer_sr, epoch_sr, visualizer,
                                           opt)
            e_sr = OrderedDict([('epoch', epoch_sr)])
            lr_sr = OrderedDict([('lr', optimizer_sr.param_groups[0]['lr'])])
            loss_sr = OrderedDict([('train_loss', train_loss_sr),
                                   ('val_loss', 0)])
            train_history_sr.update(e_sr, lr_sr, loss_sr)
            checkpoint_sr.save_checkpoint(agent_sr,
                                          optimizer_sr,
                                          train_history_sr,
                                          is_asn=True)
            visualizer.plot_train_history(train_history_sr, 'sr')
            epoch_sr += 1
    finally:
        # Close the summary file even if training raises.
        logger.close()