Пример #1
0
def main():
    """Train, or only evaluate, the ``create_hg`` pose network on MPII.

    Options come from ``TrainOptions().parse()``.  If ``opt.load_prefix_pose``
    is non-empty, the network, optimizer and training history are restored
    from that checkpoint prefix before anything else happens.  With
    ``opt.is_train`` false a single validation pass is run and its predictions
    are saved; otherwise every epoch adjusts the learning rate, trains,
    validates, records the history, writes a checkpoint and plots the history.

    Side effect: sets the ``CUDA_VISIBLE_DEVICES`` environment variable.
    """
    opt = TrainOptions().parse()
    train_history = PoseTrainHistory()
    checkpoint = Checkpoint()
    visualizer = Visualizer(opt)

    exp_dir = os.path.join(opt.exp_dir, opt.exp_id)
    visualizer.log_path = os.path.join(exp_dir, opt.vis_env + 'log.txt')
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id

    # Model and optimizer.
    num_classes = 16
    net = create_hg(num_stacks=2, num_modules=1, num_classes=num_classes,
                    chan=256)
    net = torch.nn.DataParallel(net).cuda()
    optimizer = torch.optim.RMSprop(net.parameters(), lr=opt.lr, alpha=0.99,
                                    eps=1e-8, momentum=0, weight_decay=0)

    # Optionally resume from a checkpoint.
    resuming = opt.load_prefix_pose != ''
    if resuming:
        resume_path = os.path.join(exp_dir, opt.load_prefix_pose)
        checkpoint.save_prefix = resume_path
        # The load prefix is the save prefix without its final character.
        checkpoint.load_prefix = resume_path[:-1]
        checkpoint.load_checkpoint(net, optimizer, train_history)
    else:
        checkpoint.save_prefix = exp_dir + '/'
    print('save prefix: ', checkpoint.save_prefix)

    def build_loader(for_training):
        # Only the training split is shuffled.
        split = MPII('dataset/mpii-hr-lsp-normalizer.json',
                     '/bigdata1/zt53/data',
                     is_train=for_training)
        return torch.utils.data.DataLoader(split,
                                           batch_size=opt.bs,
                                           shuffle=for_training,
                                           num_workers=opt.nThreads,
                                           pin_memory=True)

    train_loader = build_loader(True)
    val_loader = build_loader(False)

    print(type(optimizer), optimizer.param_groups[0]['lr'])

    # Indices 0..15 with 6, 7, 8, 9, 12 and 13 left out.
    idx = [j for j in range(16) if j not in (6, 7, 8, 9, 12, 13)]

    if not opt.is_train:
        # Evaluation only: one validation pass, then store the predictions.
        visualizer.log_path = os.path.join(exp_dir, 'val_log.txt')
        last_epoch = train_history.epoch[-1]['epoch']
        _, _, predictions = validate(val_loader, net, last_epoch, visualizer,
                                     idx, num_classes)
        checkpoint.save_preds(predictions)
        return

    # Training and validation.
    first_epoch = train_history.epoch[-1]['epoch'] + 1 if resuming else 0
    for epoch in range(first_epoch, opt.nEpochs):
        adjust_lr(opt, optimizer, epoch)
        train_loss, train_pckh = train(train_loader, net, optimizer, epoch,
                                       visualizer, idx, opt)
        val_loss, val_pckh, predictions = validate(val_loader, net, epoch,
                                                   visualizer, idx,
                                                   num_classes)

        # Record this epoch, then checkpoint and plot.
        train_history.update(
            OrderedDict([('epoch', epoch)]),
            OrderedDict([('lr', optimizer.param_groups[0]['lr'])]),
            OrderedDict([('train_loss', train_loss), ('val_loss', val_loss)]),
            OrderedDict([('train_pckh', train_pckh),
                         ('val_pckh', val_pckh)]))
        checkpoint.save_checkpoint(net, optimizer, train_history, predictions)
        visualizer.plot_train_history(train_history)
def main():
    """Train, or only evaluate, the ``create_cu_net`` model on MPII.

    Options come from ``TrainOptions().parse()``.  When ``opt.resume_prefix``
    is non-empty the model, optimizer and history are restored first and
    ``opt.lr`` is overwritten with the restored learning rate.  When
    ``opt.is_train`` is false a single validation pass is run and its
    predictions are saved; otherwise each epoch adjusts the learning rate,
    trains, validates, checkpoints and appends one row to
    ``training-summary.txt``.

    Side effects: sets ``CUDA_VISIBLE_DEVICES`` and rebinds the module-level
    ``quan_op``.
    """
    global quan_op

    opt = TrainOptions().parse()
    train_history = TrainHistory()
    checkpoint = Checkpoint()
    visualizer = Visualizer(opt)
    exp_dir = os.path.join(opt.exp_dir, opt.exp_id)
    log_name = opt.vis_env + 'log.txt'
    visualizer.log_name = os.path.join(exp_dir, log_name)
    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id

    num_classes = 16
    net = create_cu_net(neck_size=4,
                        growth_rate=32,
                        init_chan_num=128,
                        num_classes=num_classes,
                        layer_num=opt.layer_num,
                        max_link=1,
                        inter_loss_num=opt.layer_num)
    net = torch.nn.DataParallel(net).cuda()
    quan_op = QuanOp(net)
    optimizer = torch.optim.RMSprop(net.parameters(),
                                    lr=opt.lr,
                                    alpha=0.99,
                                    eps=1e-8,
                                    momentum=0,
                                    weight_decay=0)

    # Optionally resume from a checkpoint.  New checkpoints are always written
    # directly under exp_dir, whether or not we resumed.
    resuming = opt.resume_prefix != ''
    checkpoint.save_prefix = exp_dir + '/'
    if resuming:
        # The load prefix is the joined path without its final character.
        checkpoint.load_prefix = os.path.join(exp_dir, opt.resume_prefix)[0:-1]
        checkpoint.load_checkpoint(net, optimizer, train_history)
        opt.lr = optimizer.param_groups[0]['lr']
    # Single-argument print() behaves the same on Python 2 and Python 3.
    print('save prefix:  {}'.format(checkpoint.save_prefix))

    # Load data.
    train_loader = torch.utils.data.DataLoader(MPII(
        'dataset/mpii-hr-lsp-normalizer.json',
        '/bigdata1/zt53/data',
        is_train=True),
                                               batch_size=opt.bs,
                                               shuffle=True,
                                               num_workers=opt.nThreads,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(MPII(
        'dataset/mpii-hr-lsp-normalizer.json',
        '/bigdata1/zt53/data',
        is_train=False),
                                             batch_size=opt.bs,
                                             shuffle=False,
                                             num_workers=opt.nThreads,
                                             pin_memory=True)
    print(type(optimizer))

    # Indices 0..15 with 6, 7, 8, 9, 12 and 13 left out.
    idx = [0, 1, 2, 3, 4, 5, 10, 11, 14, 15]
    logger = Logger(os.path.join(opt.exp_dir, opt.exp_id,
                                 'training-summary.txt'),
                    title='training-summary',
                    resume=resuming)
    # try/finally so the summary file is closed on the evaluation-only return
    # and when training raises, not just after a clean run.
    try:
        logger.set_names(
            ['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train Acc', 'Val Acc'])
        if not opt.is_train:
            visualizer.log_path = os.path.join(opt.exp_dir, opt.exp_id,
                                               'val_log.txt')
            val_loss, val_pckh, predictions = validate(
                val_loader, net, train_history.epoch[-1]['epoch'], visualizer,
                idx, joint_flip_index, num_classes)
            checkpoint.save_preds(predictions)
            return

        # Training and validation.
        start_epoch = 0
        if resuming:
            start_epoch = train_history.epoch[-1]['epoch'] + 1
        for epoch in range(start_epoch, opt.nEpochs):
            adjust_lr(opt, optimizer, epoch)
            # Train for one epoch.
            train_loss, train_pckh = train(train_loader, net, optimizer,
                                           epoch, visualizer, idx, opt)

            # Evaluate on the validation set.
            val_loss, val_pckh, predictions = validate(val_loader, net, epoch,
                                                       visualizer, idx,
                                                       joint_flip_index,
                                                       num_classes)

            # Update the training history and write a checkpoint.
            e = OrderedDict([('epoch', epoch)])
            lr = OrderedDict([('lr', optimizer.param_groups[0]['lr'])])
            loss = OrderedDict([('train_loss', train_loss),
                                ('val_loss', val_loss)])
            pckh = OrderedDict([('val_pckh', val_pckh)])
            train_history.update(e, lr, loss, pckh)
            checkpoint.save_checkpoint(net, optimizer, train_history,
                                       predictions)
            logger.append([
                epoch, optimizer.param_groups[0]['lr'], train_loss, val_loss,
                train_pckh, val_pckh
            ])
    finally:
        logger.close()
def main():
    """Pre-train the ``model.create_asn`` agent against a fixed pose network.

    Options come from ``TrainOptions().parse()``.  ``opt.sr_dir`` and
    ``opt.load_prefix_pose`` are required and ``opt.dataset`` must be
    ``'mpii'``; all three are checked before any directory, data loader or
    model is created.  The pose network built by ``model.create_hg`` is loaded
    from the ``opt.load_prefix_pose`` checkpoint and is not given an
    optimizer here.  Each epoch trains and validates the agent, records the
    history, writes a checkpoint, plots the history and appends a row to
    ``training-summary.txt``.

    Side effects: sets ``CUDA_VISIBLE_DEVICES`` and creates the pre-training
    directory if it is missing.

    Raises:
        SystemExit: with status 1 when a required option is empty.
        ValueError: when ``opt.dataset`` is not ``'mpii'``.
    """
    import sys  # local import: the file's import block is outside this block

    opt = TrainOptions().parse()

    # Validate everything up front so a bad invocation fails before any
    # directory is created or any data/model is loaded.
    if opt.sr_dir == '':
        print('sr directory is null.')
        sys.exit(1)
    if opt.load_prefix_pose == '':
        print('please input the checkpoint name of the pose model')
        sys.exit(1)
    if opt.dataset != 'mpii':
        # Previously this surfaced later as a NameError on num_classes.
        raise ValueError(
            "unsupported dataset '{}': only 'mpii' is handled".format(
                opt.dataset))
    num_classes = 16

    sr_pretrain_dir = os.path.join(
        opt.exp_dir, opt.exp_id, opt.sr_dir + '-' + opt.load_prefix_pose[0:-1])
    if not os.path.isdir(sr_pretrain_dir):
        os.makedirs(sr_pretrain_dir)
    train_history = ASNTrainHistory()
    checkpoint_agent = Checkpoint()
    visualizer = Visualizer(opt)
    visualizer.log_path = sr_pretrain_dir + '/' + 'log.txt'
    train_scale_path = sr_pretrain_dir + '/' + 'train_scales.txt'
    train_rotation_path = sr_pretrain_dir + '/' + 'train_rotations.txt'
    val_scale_path = sr_pretrain_dir + '/' + 'val_scales.txt'
    val_rotation_path = sr_pretrain_dir + '/' + 'val_rotations.txt'

    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id

    # Single-argument print() behaves the same on Python 2 and Python 3.
    print('collecting training scale and rotation distributions ...\n')
    train_scale_distri = read_grnd_distri_from_txt(train_scale_path)
    train_rotation_distri = read_grnd_distri_from_txt(train_rotation_path)
    dataset = MPII('dataset/mpii-hr-lsp-normalizer.json',
                   '/bigdata1/zt53/data',
                   is_train=True,
                   grnd_scale_distri=train_scale_distri,
                   grnd_rotation_distri=train_rotation_distri)
    train_loader = torch.utils.data.DataLoader(dataset,
                                               batch_size=opt.bs,
                                               shuffle=True,
                                               num_workers=opt.nThreads,
                                               pin_memory=True)

    print('collecting validation scale and rotation distributions ...\n')
    val_scale_distri = read_grnd_distri_from_txt(val_scale_path)
    val_rotation_distri = read_grnd_distri_from_txt(val_rotation_path)
    dataset = MPII('dataset/mpii-hr-lsp-normalizer.json',
                   '/bigdata1/zt53/data',
                   is_train=False,
                   grnd_scale_distri=val_scale_distri,
                   grnd_rotation_distri=val_rotation_distri)
    val_loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=opt.bs,
                                             shuffle=False,
                                             num_workers=opt.nThreads,
                                             pin_memory=True)

    # `dataset` is the validation dataset at this point; the agent's output
    # sizes are taken from its scale_means / rotation_means.
    agent = model.create_asn(chan_in=256,
                             chan_out=256,
                             scale_num=len(dataset.scale_means),
                             rotation_num=len(dataset.rotation_means),
                             is_aug=True)
    agent = torch.nn.DataParallel(agent).cuda()
    optimizer = torch.optim.RMSprop(agent.parameters(),
                                    lr=opt.lr,
                                    alpha=0.99,
                                    eps=1e-8,
                                    momentum=0,
                                    weight_decay=0)
    resuming = opt.load_prefix_sr != ''
    if not resuming:
        checkpoint_agent.save_prefix = sr_pretrain_dir + '/'
    else:
        checkpoint_agent.save_prefix = sr_pretrain_dir + '/' + opt.load_prefix_sr
        # The load prefix is the save prefix without its final character.
        checkpoint_agent.load_prefix = checkpoint_agent.save_prefix[0:-1]
        checkpoint_agent.load_checkpoint(agent, optimizer, train_history)
    print('agent:  {} {}'.format(type(optimizer),
                                 optimizer.param_groups[0]['lr']))

    # Pose network: restored from its checkpoint, no optimizer attached.
    hg = model.create_hg(num_stacks=2,
                         num_modules=1,
                         num_classes=num_classes,
                         chan=256)
    hg = torch.nn.DataParallel(hg).cuda()
    checkpoint_hg = Checkpoint()
    checkpoint_hg.load_prefix = os.path.join(opt.exp_dir, opt.exp_id,
                                             opt.load_prefix_pose)[0:-1]
    checkpoint_hg.load_checkpoint(hg)

    # NOTE(review): the Logger is opened without a resume flag even when
    # opt.load_prefix_sr is set -- confirm this does not discard an existing
    # summary when resuming.
    logger = Logger(sr_pretrain_dir + '/' + 'training-summary.txt',
                    title='training-summary')
    # try/finally so the summary file is closed even if an epoch raises.
    try:
        logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss'])

        # Training and validation.
        start_epoch = 0
        if resuming:
            start_epoch = train_history.epoch[-1]['epoch'] + 1
        for epoch in range(start_epoch, opt.nEpochs):
            # Train for one epoch, then validate.
            train_loss = train(train_loader, hg, agent, optimizer, epoch,
                               visualizer, opt)
            val_loss = validate(val_loader, hg, agent, epoch, visualizer, opt)

            # Update the training history and write a checkpoint.
            e = OrderedDict([('epoch', epoch)])
            lr = OrderedDict([('lr', optimizer.param_groups[0]['lr'])])
            loss = OrderedDict([('train_loss', train_loss),
                                ('val_loss', val_loss)])
            train_history.update(e, lr, loss)
            checkpoint_agent.save_checkpoint(agent,
                                             optimizer,
                                             train_history,
                                             is_asn=True)
            visualizer.plot_train_history(train_history, 'sr')
            logger.append(
                [epoch, optimizer.param_groups[0]['lr'], train_loss, val_loss])
    finally:
        logger.close()
def main():
    """Jointly train the ``model.create_hg`` pose network and the ASN agent.

    Options come from ``TrainOptions().parse()``.  ``opt.joint_dir``,
    ``opt.load_prefix_pose`` and ``opt.load_prefix_sr`` are required and
    ``opt.dataset`` must be ``'mpii'``; all are checked before any directory
    or model is created.  Both networks are restored from checkpoints: from
    the joint directory when ``opt.load_checkpoint`` is set, otherwise from
    the experiment directory and the agent pre-training directory, in which
    case their learning rates are reset to ``opt.lr`` / ``opt.agent_lr``.

    With ``opt.is_train`` false a single validation pass of the pose network
    is run and its predictions are saved.  Otherwise each epoch trains and
    validates the pose network, then trains the agent for one epoch, with
    history, checkpoint and plot updates for both.

    Side effects: sets ``CUDA_VISIBLE_DEVICES`` and creates the joint
    directory if it is missing.

    Raises:
        SystemExit: with status 1 when a required option is empty.
        ValueError: when ``opt.dataset`` is not ``'mpii'``.
    """
    import sys  # local import: the file's import block is outside this block

    opt = TrainOptions().parse()

    # Validate everything up front so a bad invocation fails before any
    # directory is created or any model is built.
    if opt.joint_dir == '':
        print('joint directory is null.')
        sys.exit(1)
    if opt.load_prefix_pose == '':
        print('please input the checkpoint name of the pose model')
        sys.exit(1)
    if opt.load_prefix_sr == '':
        print('please input the checkpoint name of the sr agent.')
        sys.exit(1)
    if opt.dataset != 'mpii':
        # Previously this surfaced later as a NameError on num_classes.
        raise ValueError(
            "unsupported dataset '{}': only 'mpii' is handled".format(
                opt.dataset))
    num_classes = 16

    joint_dir = os.path.join(opt.exp_dir, opt.exp_id,
                             opt.joint_dir + '-' + opt.load_prefix_pose[0:-1])
    if not os.path.isdir(joint_dir):
        os.makedirs(joint_dir)

    visualizer = Visualizer(opt)
    visualizer.log_path = joint_dir + '/' + 'train-log.txt'

    os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id

    # Pose network and its optimizer.
    hg = model.create_hg(num_stacks=2,
                         num_modules=1,
                         num_classes=num_classes,
                         chan=256)
    hg = torch.nn.DataParallel(hg).cuda()
    optimizer_hg = torch.optim.RMSprop(hg.parameters(),
                                       lr=opt.lr,
                                       alpha=0.99,
                                       eps=1e-8,
                                       momentum=0,
                                       weight_decay=0)
    train_history_pose = PoseTrainHistory()
    checkpoint_hg = Checkpoint()
    if opt.load_checkpoint:
        # Continue a previous joint run.
        checkpoint_hg.load_prefix = joint_dir + '/' + opt.load_prefix_pose[0:-1]
        checkpoint_hg.load_checkpoint(hg, optimizer_hg, train_history_pose)
    else:
        # Start from the stand-alone pose checkpoint, with a fresh LR.
        checkpoint_hg.load_prefix = os.path.join(opt.exp_dir, opt.exp_id) + \
                                    '/' + opt.load_prefix_pose[0:-1]
        checkpoint_hg.load_checkpoint(hg, optimizer_hg, train_history_pose)
        for param_group in optimizer_hg.param_groups:
            param_group['lr'] = opt.lr
    checkpoint_hg.save_prefix = joint_dir + '/pose-'
    # Single-argument print() behaves the same on Python 2 and Python 3.
    print('hg optimizer:  {} {}'.format(type(optimizer_hg),
                                        optimizer_hg.param_groups[0]['lr']))

    # NOTE(review): `dataset` is not assigned anywhere in this function; it
    # must resolve to a module-level name exposing scale_means and
    # rotation_means, otherwise this raises NameError -- confirm.
    agent_sr = model.create_asn(chan_in=256,
                                chan_out=256,
                                scale_num=len(dataset.scale_means),
                                rotation_num=len(dataset.rotation_means),
                                is_aug=True)
    agent_sr = torch.nn.DataParallel(agent_sr).cuda()
    optimizer_sr = torch.optim.RMSprop(agent_sr.parameters(),
                                       lr=opt.agent_lr,
                                       alpha=0.99,
                                       eps=1e-8,
                                       momentum=0,
                                       weight_decay=0)
    train_history_sr = ASNTrainHistory()
    checkpoint_sr = Checkpoint()
    if opt.load_checkpoint:
        checkpoint_sr.load_prefix = joint_dir + '/' + opt.load_prefix_sr[0:-1]
        checkpoint_sr.load_checkpoint(agent_sr, optimizer_sr, train_history_sr)
    else:
        # Start from the agent pre-training directory, with a fresh LR.
        sr_pretrain_dir = os.path.join(
            opt.exp_dir, opt.exp_id,
            opt.sr_dir + '-' + opt.load_prefix_pose[0:-1])
        checkpoint_sr.load_prefix = sr_pretrain_dir + '/' + opt.load_prefix_sr[
            0:-1]
        checkpoint_sr.load_checkpoint(agent_sr, optimizer_sr, train_history_sr)
        for param_group in optimizer_sr.param_groups:
            param_group['lr'] = opt.agent_lr
    checkpoint_sr.save_prefix = joint_dir + '/agent-'
    print('agent optimizer:  {} {}'.format(
        type(optimizer_sr), optimizer_sr.param_groups[0]['lr']))

    # Data loaders.
    train_dataset_hg = MPII('dataset/mpii-hr-lsp-normalizer.json',
                            '/bigdata1/zt53/data',
                            is_train=True)
    train_loader_hg = torch.utils.data.DataLoader(train_dataset_hg,
                                                  batch_size=opt.bs,
                                                  shuffle=True,
                                                  num_workers=opt.nThreads,
                                                  pin_memory=True)
    val_dataset_hg = MPII('dataset/mpii-hr-lsp-normalizer.json',
                          '/bigdata1/zt53/data',
                          is_train=False)
    val_loader_hg = torch.utils.data.DataLoader(val_dataset_hg,
                                                batch_size=opt.bs,
                                                shuffle=False,
                                                num_workers=opt.nThreads,
                                                pin_memory=True)
    train_dataset_agent = AGENT('dataset/mpii-hr-lsp-normalizer.json',
                                '/bigdata1/zt53/data',
                                separate_s_r=True)
    train_loader_agent = torch.utils.data.DataLoader(train_dataset_agent,
                                                     batch_size=opt.bs,
                                                     shuffle=True,
                                                     num_workers=opt.nThreads,
                                                     pin_memory=True)

    if not opt.is_train:
        # Evaluation only: one validation pass, then store the predictions.
        visualizer.log_path = joint_dir + '/' + 'val-log.txt'
        val_loss, val_pckh, predictions = validate(
            val_loader_hg, hg, train_history_pose.epoch[-1]['epoch'],
            visualizer, num_classes)
        checkpoint_hg.save_preds(predictions)
        return

    logger = Logger(joint_dir + '/' + 'pose-training-summary.txt',
                    title='pose-training-summary')
    # try/finally so the summary file is closed even if an epoch raises.
    try:
        logger.set_names([
            'Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train PCKh', 'Val PCKh'
        ])

        # Both networks continue from the epoch after their last recorded
        # one; the agent keeps its own epoch counter.
        start_epoch_pose = train_history_pose.epoch[-1]['epoch'] + 1
        epoch_sr = train_history_sr.epoch[-1]['epoch'] + 1

        for epoch in range(start_epoch_pose, opt.nEpochs):
            adjust_lr(opt, optimizer_hg, epoch)
            # Train the pose network for one epoch.
            train_loss_pose, train_pckh = train_hg(train_loader_hg, hg,
                                                   optimizer_hg, agent_sr,
                                                   epoch, visualizer, opt)
            # Evaluate on the validation set.
            val_loss, val_pckh, predictions = validate(val_loader_hg, hg,
                                                       epoch, visualizer,
                                                       num_classes)

            # Update the pose history and write a checkpoint.
            e_pose = OrderedDict([('epoch', epoch)])
            lr_pose = OrderedDict([('lr',
                                    optimizer_hg.param_groups[0]['lr'])])
            loss_pose = OrderedDict([('train_loss', train_loss_pose),
                                     ('val_loss', val_loss)])
            pckh = OrderedDict([('train_pckh', train_pckh),
                                ('val_pckh', val_pckh)])
            train_history_pose.update(e_pose, lr_pose, loss_pose, pckh)
            checkpoint_hg.save_checkpoint(hg, optimizer_hg,
                                          train_history_pose, predictions)
            visualizer.plot_train_history(train_history_pose)
            logger.append([
                epoch, optimizer_hg.param_groups[0]['lr'], train_loss_pose,
                val_loss, train_pckh, val_pckh
            ])

            # Train the agent for one epoch; it has no validation pass here,
            # so its val_loss is recorded as 0.
            train_loss_sr = train_agent_sr(train_loader_agent, hg, agent_sr,
                                           optimizer_sr, epoch_sr, visualizer,
                                           opt)
            e_sr = OrderedDict([('epoch', epoch_sr)])
            lr_sr = OrderedDict([('lr', optimizer_sr.param_groups[0]['lr'])])
            loss_sr = OrderedDict([('train_loss', train_loss_sr),
                                   ('val_loss', 0)])
            train_history_sr.update(e_sr, lr_sr, loss_sr)
            checkpoint_sr.save_checkpoint(agent_sr,
                                          optimizer_sr,
                                          train_history_sr,
                                          is_asn=True)
            visualizer.plot_train_history(train_history_sr, 'sr')
            epoch_sr += 1
    finally:
        logger.close()
Пример #5
0
def main():
    """Train, or only evaluate, the ``create_cu_net`` face landmark model.

    Options come from ``TrainOptions().parse()``.  Weights are optionally
    copied in from ``opt.saved_wt_file``.  With ``opt.freeze`` set, only the
    ``net.cholesky.fc_1/fc_2/fc_3`` weights and biases stay trainable.  The
    optimizer is chosen by ``opt.optimizer``, and training state is optionally
    resumed from ``opt.resume_prefix``.  With ``opt.is_train`` false a single
    validation pass is run and its predictions are saved; otherwise each epoch
    applies the selected learning-rate policy, trains, validates, checkpoints,
    plots and appends a row to the training log.

    Side effects: sets ``CUDA_VISIBLE_DEVICES`` unless ``opt.slurm`` is set,
    and rebinds the module-level ``weights_HG``.

    Raises:
        SystemExit: with status 1 when ``opt.optimizer`` is not recognised.
    """
    global weights_HG

    opt = TrainOptions().parse()
    train_history = TrainHistoryFace()
    checkpoint = Checkpoint()
    visualizer = Visualizer(opt)
    exp_dir = os.path.join(opt.exp_dir, opt.exp_id)
    log_name = opt.vis_env + 'log.txt'
    visualizer.log_name = os.path.join(exp_dir, log_name)
    num_classes = opt.class_num

    if not opt.slurm:
        os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id

    layer_num = opt.layer_num
    order = opt.order
    net = create_cu_net(neck_size=4, growth_rate=32, init_chan_num=128,
                        class_num=num_classes, layer_num=layer_num,
                        order=order, loss_num=layer_num,
                        use_spatial_transformer=opt.stn,
                        mlp_tot_layers=opt.mlp_tot_layers,
                        mlp_hidden_units=opt.mlp_hidden_units,
                        get_mean_from_mlp=opt.get_mean_from_mlp)

    # Load the pre-trained model.
    saved_wt_file = opt.saved_wt_file
    if saved_wt_file == "":
        print("=> Training from scratch")
    else:
        print("=> Loading weights from " + saved_wt_file)
        checkpoint_t = torch.load(saved_wt_file)
        state_dict = checkpoint_t['state_dict']

        # Fetch the target state dict once instead of rebuilding it for every
        # lookup inside the loop.
        own_state = net.state_dict()
        for name, param in state_dict.items():
            # Drop the first 7 characters of the saved key -- presumably the
            # 'module.' prefix added by DataParallel; verify against the
            # checkpoint.
            name = name[7:]
            if name not in own_state:
                print("=> not load weights '{}'".format(name))
                continue
            if isinstance(param, Parameter):
                param = param.data
            # Only the leading dimension is compared before copying.
            if own_state[name].shape[0] == param.shape[0]:
                own_state[name].copy_(param)
            else:
                print("First dim different. Not loading weights {}".format(name))

    if opt.freeze:
        print("\n\t\tFreezing basenet parameters\n")
        for param in net.parameters():
            param.requires_grad = False

        # Re-enable gradients for the three fully connected layers of
        # net.cholesky only.
        net.cholesky.fc_1.bias.requires_grad = True
        net.cholesky.fc_1.weight.requires_grad = True
        net.cholesky.fc_2.bias.requires_grad = True
        net.cholesky.fc_2.weight.requires_grad = True
        net.cholesky.fc_3.bias.requires_grad = True
        net.cholesky.fc_3.weight.requires_grad = True
    else:
        print("\n\t\tNot freezing anything. Tuning every parameter\n")
        for param in net.parameters():
            param.requires_grad = True

    net = torch.nn.DataParallel(net).cuda()  # use multiple GPUs

    # Optimizer: only parameters that still require gradients are passed in.
    if opt.optimizer == "rmsprop":
        optimizer = torch.optim.RMSprop(
            filter(lambda p: p.requires_grad, net.parameters()),
            lr=opt.lr, alpha=0.99, eps=1e-8, momentum=0, weight_decay=0)
    elif opt.optimizer == "adam":
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, net.parameters()), lr=opt.lr)
    else:
        print("Unknown Optimizer. Aborting!!!")
        # Non-zero status: this is an error exit, not a successful run.
        sys.exit(1)
    # Single-argument print() behaves the same on Python 2 and Python 3.
    print(type(optimizer))

    # Optionally resume from a checkpoint.
    resuming = opt.resume_prefix != ''
    if resuming:
        checkpoint.save_prefix = os.path.join(exp_dir, opt.resume_prefix)
        # The load prefix is the save prefix without its final character.
        checkpoint.load_prefix = os.path.join(exp_dir, opt.resume_prefix)[0:-1]
        checkpoint.load_checkpoint(net, optimizer, train_history)
    else:
        checkpoint.save_prefix = exp_dir + '/'
    print("Save prefix                           = {}".format(checkpoint.save_prefix))

    # Load data.
    json_path = opt.json_path
    train_json = opt.train_json
    val_json = opt.val_json

    print("Path added to each image path in JSON = {}".format(json_path))
    print("Train JSON path                       = {}".format(train_json))
    print("Val JSON path                         = {}".format(val_json))

    if opt.bulat_aug:
        # Use Bulat et al Augmentation Scheme
        train_loader = torch.utils.data.DataLoader(
            FACE(train_json, json_path, is_train=True, scale_factor=0.2,
                 rot_factor=50, use_occlusion=True, keep_pts_inside=True),
            batch_size=opt.bs, shuffle=True,
            num_workers=opt.nThreads, pin_memory=True)
    else:
        train_loader = torch.utils.data.DataLoader(
            FACE(train_json, json_path, is_train=True, keep_pts_inside=True),
            batch_size=opt.bs, shuffle=True,
            num_workers=opt.nThreads, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        FACE(val_json, json_path, is_train=False),
        batch_size=opt.bs, shuffle=False,
        num_workers=opt.nThreads, pin_memory=True)

    logger = Logger(os.path.join(opt.exp_dir, opt.exp_id,
                                 opt.resume_prefix + 'face-training-log.txt'),
                    title='face-training-summary')
    # try/finally so the log file is closed on the evaluation-only return and
    # when training raises, not just after a clean run.
    try:
        logger.set_names(['Epoch', 'LR', 'Train Loss', 'Val Loss', 'Train RMSE', 'Val RMSE', 'Train RMSE Box', 'Val RMSE Box', 'Train RMSE Meta', 'Val RMSE Meta'])
        if not opt.is_train:
            visualizer.log_path = os.path.join(opt.exp_dir, opt.exp_id,
                                               'val_log.txt')
            # Call validate() the same way the training loop below does
            # (with `opt`); predictions are the last element of its result.
            predictions = validate(val_loader, net,
                                   train_history.epoch[-1]['epoch'],
                                   visualizer, opt, num_classes,
                                   flip_index)[-1]
            checkpoint.save_preds(predictions)
            return

        weights_HG = [float(x) for x in opt.hg_wt.split(",")]

        if opt.is_covariance:
            print("Covariance used from the heatmap")
        else:
            print("Covariance calculated from MLP")

        if opt.stn:
            print("Using spatial transformer on heatmaps")
        print("Postprocessing applied                = {}".format(opt.pp))
        if opt.smax:
            print("Scaled softmax used with tau          = {}".format(opt.tau))
        else:
            print("No softmax used")

        print("Individual Hourglass loss weights")
        print(weights_HG)
        print("wt_MSE (tradeoff between GLL and MSE in each hourglass)= " + str(opt.wt_mse))
        print("wt_gauss_regln (tradeoff between GLL and Gaussian Regularisation in each hourglass)= " + str(opt.wt_gauss_regln))

        if opt.bulat_aug:
            print("Using Bulat et al, ICCV 2017 Augmentation Scheme")

        print("Using Learning Policy {}".format(opt.lr_policy))
        chosen_lr_policy = dict_of_functions[opt.lr_policy]

        # Continue after the last recorded epoch when resuming.
        start_epoch = 0
        if resuming:
            start_epoch = train_history.epoch[-1]['epoch'] + 1

        # Per-epoch metric traces; they are only appended to in this function.
        train_loss_orig_epoch = []
        train_loss_gau_t1_epoch = []
        train_loss_gau_t2_epoch = []
        train_nme_orig_epoch = []
        train_nme_gau_epoch = []

        val_loss_orig_epoch = []
        val_loss_gau_t1_epoch = []
        val_loss_gau_t2_epoch = []
        val_nme_orig_epoch = []
        val_nme_gau_epoch = []

        for epoch in range(start_epoch, opt.nEpochs):
            chosen_lr_policy(opt, optimizer, epoch)

            # Train for one epoch.
            (train_loss, train_loss_mse, train_loss_gau_t1, train_loss_gau_t2,
             train_rmse_orig, train_rmse_gau, train_rmse_new_gd_box,
             train_rmse_new_meta_box) = train(train_loader, net, optimizer,
                                              epoch, visualizer, opt)
            train_loss_gau_t1_epoch.append(train_loss_gau_t1)
            train_loss_gau_t2_epoch.append(train_loss_gau_t2)
            train_nme_orig_epoch.append(train_rmse_orig)
            train_nme_gau_epoch.append(train_rmse_gau)
            train_loss_orig_epoch.append(train_loss_mse)

            # Evaluate on the validation set.
            (val_loss, val_loss_mse, val_loss_gau_t1, val_loss_gau_t2,
             val_rmse_orig, val_rmse_gau, val_rmse_new_gd_box,
             val_rmse_new_meta_box, predictions) = validate(
                 val_loader, net, epoch, visualizer, opt, num_classes,
                 flip_index)
            val_loss_orig_epoch.append(val_loss_mse)
            val_loss_gau_t1_epoch.append(val_loss_gau_t1)
            val_loss_gau_t2_epoch.append(val_loss_gau_t2)
            val_nme_orig_epoch.append(val_rmse_orig)
            val_nme_gau_epoch.append(val_rmse_gau)

            # Update the training history and write a checkpoint.
            e = OrderedDict([('epoch', epoch)])
            lr = OrderedDict([('lr', optimizer.param_groups[0]['lr'])])
            loss = OrderedDict([('train_loss', train_loss),
                                ('val_loss', val_loss)])
            rmse = OrderedDict([('val_rmse', val_rmse_gau)])
            train_history.update(e, lr, loss, rmse)
            checkpoint.save_checkpoint(net, optimizer, train_history,
                                       predictions)
            visualizer.plot_train_history_face(train_history)
            logger.append([epoch, optimizer.param_groups[0]['lr'],
                           train_loss, val_loss, train_rmse_gau, val_rmse_gau,
                           train_rmse_new_gd_box, val_rmse_new_gd_box,
                           train_rmse_new_meta_box, val_rmse_new_meta_box])
    finally:
        logger.close()
Пример #6
0
def main():
    opt = TrainOptions().parse()
    train_history = TrainHistoryFace()
    checkpoint = Checkpoint()
    visualizer = Visualizer(opt)
    exp_dir = os.path.join(opt.exp_dir, opt.exp_id)
    log_name = opt.vis_env + '_val_log.txt'
    visualizer.log_name = os.path.join(exp_dir, log_name)
    num_classes = opt.class_num

    if not opt.slurm:
        os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id

    layer_num = opt.layer_num
    order = opt.order
    net = create_cu_net(neck_size=4,
                        growth_rate=32,
                        init_chan_num=128,
                        class_num=num_classes,
                        layer_num=layer_num,
                        order=order,
                        loss_num=layer_num,
                        use_spatial_transformer=opt.stn,
                        mlp_tot_layers=opt.mlp_tot_layers,
                        mlp_hidden_units=opt.mlp_hidden_units,
                        get_mean_from_mlp=opt.get_mean_from_mlp)

    # Load the pre-trained model
    saved_wt_file = opt.saved_wt_file
    print("Loading weights from " + saved_wt_file)
    checkpoint_t = torch.load(saved_wt_file)
    state_dict = checkpoint_t['state_dict']

    for name, param in state_dict.items():
        name = name[7:]
        if name not in net.state_dict():
            print("=> not load weights '{}'".format(name))
            continue
        if isinstance(param, Parameter):
            param = param.data
        net.state_dict()[name].copy_(param)

    net = torch.nn.DataParallel(net).cuda()  # use multiple GPUs

    # Optimizer
    if opt.optimizer == "rmsprop":
        optimizer = torch.optim.RMSprop(filter(lambda p: p.requires_grad,
                                               net.parameters()),
                                        lr=opt.lr,
                                        alpha=0.99,
                                        eps=1e-8,
                                        momentum=0,
                                        weight_decay=0)
    elif opt.optimizer == "adam":
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            net.parameters()),
                                     lr=opt.lr)
    else:
        print("Unknown Optimizer. Aborting!!!")
        sys.exit(0)
    print(type(optimizer))

    # Optionally resume from a checkpoint
    if opt.resume_prefix != '':
        # if 'pth' in opt.resume_prefix:
        #     trunc_index = opt.resume_prefix.index('pth')
        #     opt.resume_prefix = opt.resume_prefix[0:trunc_index - 1]
        checkpoint.save_prefix = os.path.join(exp_dir, opt.resume_prefix)
        checkpoint.load_prefix = os.path.join(exp_dir, opt.resume_prefix)[0:-1]
        checkpoint.load_checkpoint(net, optimizer, train_history)
    else:
        checkpoint.save_prefix = exp_dir + '/'
    print("Save prefix                           = {}".format(
        checkpoint.save_prefix))

    # Load data
    json_path = opt.json_path
    train_json = opt.train_json
    val_json = opt.val_json

    print("Path added to each image path in JSON = {}".format(json_path))
    print("Train JSON path                       = {}".format(train_json))
    print("Val JSON path                         = {}".format(val_json))

    # This train loader is useless
    train_loader = torch.utils.data.DataLoader(FACE(train_json,
                                                    json_path,
                                                    is_train=True),
                                               batch_size=opt.bs,
                                               shuffle=True,
                                               num_workers=opt.nThreads,
                                               pin_memory=True)

    val_loader = torch.utils.data.DataLoader(FACE(val_json,
                                                  json_path,
                                                  is_train=False),
                                             batch_size=opt.bs,
                                             shuffle=False,
                                             num_workers=opt.nThreads,
                                             pin_memory=True)

    if not opt.is_train:
        visualizer.log_path = os.path.join(opt.exp_dir, opt.exp_id,
                                           'val_log.txt')
        val_loss, val_rmse, predictions = validate(
            val_loader, net, train_history.epoch[-1]['epoch'], visualizer,
            num_classes, flip_index)
        checkpoint.save_preds(predictions)
        return

    global f_path
    global weights_HG

    f_path = exp_dir
    weights_HG = [float(x) for x in opt.hg_wt.split(",")]

    print("Postprocessing applied                = {}".format(opt.pp))
    if (opt.smax):
        print("Scaled softmax used with tau          = {}".format(opt.tau))
    else:
        print("No softmax used")

    if opt.is_covariance:
        print("Covariance used from the heatmap")
    else:
        print("Covariance calculated from MLP")

    print("Individual Hourglass loss weights")
    print(weights_HG)
    print("wt_MSE (tradeoff between GLL and MSE in each hourglass)= " +
          str(opt.wt_mse))
    print(
        "wt_gauss_regln (tradeoff between GLL and Gaussian Regularisation in each hourglass)= "
        + str(opt.wt_gauss_regln))

    # Optionally resume from a checkpoint
    start_epoch = 0
    if opt.resume_prefix != '':
        start_epoch = train_history.epoch[-1]['epoch'] + 1

    # Training and validation
    start_epoch = 0
    if opt.resume_prefix != '':
        start_epoch = train_history.epoch[-1]['epoch'] + 1

    train_loss_orig_epoch = []
    train_loss_gau_t1_epoch = []
    train_loss_gau_t2_epoch = []
    train_nme_orig_epoch = []
    train_nme_gau_epoch = []
    train_nme_new_epoch = []

    val_loss_orig_epoch = []
    val_loss_gau_t1_epoch = []
    val_loss_gau_t2_epoch = []
    val_nme_orig_epoch = []
    val_nme_gau_epoch = []
    val_nme_new_epoch = []

    for epoch in range(1):
        # Evaluate on validation set
        val_loss, val_loss_mse, val_loss_gau_t1, val_loss_gau_t2, val_rmse_orig, val_rmse_gau, val_rmse_new_box, predictions = validate(
            val_loader, net, epoch, visualizer, opt, num_classes, flip_index)
        val_loss_orig_epoch.append(val_loss_mse)
        val_loss_gau_t1_epoch.append(val_loss_gau_t1)
        val_loss_gau_t2_epoch.append(val_loss_gau_t2)
        val_nme_orig_epoch.append(val_rmse_orig)
        val_nme_gau_epoch.append(val_rmse_gau)