Example #1
def main():
    print('Training Process\nInitializing...\n')
    config.init_env()

    val_dataset = data_pth.view_data(config.view_net.data_root,
                                     status=STATUS_TEST,
                                     base_model_name=config.base_model_name)

    val_loader = DataLoader(val_dataset, batch_size=config.view_net.train.batch_sz,
                            num_workers=config.num_workers, shuffle=False)


    # create model
    net = MVCNN()
    net = net.to(device=config.device)
    net = nn.DataParallel(net)

    print(f'loading pretrained model from {config.view_net.ckpt_file}')
    checkpoint = torch.load(config.view_net.ckpt_file)
    net.module.load_state_dict(checkpoint['model'])
    best_prec1 = checkpoint['best_prec1']

    with torch.no_grad():
        prec1 = validate(val_loader, net)

    print('curr accuracy: ', prec1)
    print('best accuracy: ', best_prec1)

    print('Train Finished!')
Example #2
def main():
    print('Training Process\nInitializing...\n')
    config.init_env()

    val_dataset = data_pth.pc_data(config.pc_net.data_root, status=STATUS_TEST)

    val_loader = DataLoader(val_dataset, batch_size=config.pc_net.validation.batch_sz,
                            num_workers=config.num_workers, shuffle=True, drop_last=True)

    # create model
    net = DGCNN(n_neighbor=config.pc_net.n_neighbor, num_classes=config.pc_net.num_classes)
    net = torch.nn.DataParallel(net)
    net = net.to(device=config.device)
    optimizer = optim.Adam(net.parameters(), config.pc_net.train.lr,
                          weight_decay=config.pc_net.train.weight_decay)

    print(f'loading pretrained model from {config.pc_net.ckpt_file}')
    checkpoint = torch.load(config.pc_net.ckpt_file)
    net.module.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    best_prec1 = checkpoint['best_prec1']
    resume_epoch = checkpoint['epoch']

    # scheduler and loss are constructed but unused in this validation-only run
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 5, 0.5)
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device=config.device)

    # for p in net.module.feature.parameters():
    #     p.requires_grad = False

    with torch.no_grad():
        prec1 = validate(val_loader, net, resume_epoch)

    print('curr accuracy: ', prec1)
    print('best accuracy: ', best_prec1)
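
The save_ckpt helper called in the training examples below is never defined in this listing. The following is only a minimal sketch, assuming it writes exactly the keys this example reads back ('model', 'optimizer', 'best_prec1', 'epoch'); the path argument is a placeholder for the configured ckpt_file:

import torch

def save_ckpt(epoch, best_prec1, net, optimizer, path='ckpt.pth'):
    # store the unwrapped module's weights so the checkpoint loads
    # with or without the nn.DataParallel wrapper
    torch.save({
        'model': net.module.state_dict(),
        'optimizer': optimizer.state_dict(),
        'best_prec1': best_prec1,
        'epoch': epoch,
    }, path)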
Example #3
def main():
    print('Training Process\nInitializing...\n')
    config.init_env()

    val_dataset = view_data(config.view_net.data_root,
                            status=STATUS_TEST,
                            base_model_name=config.base_model_name)

    val_loader = DataLoader(val_dataset,
                            batch_size=config.view_net.test.batch_sz,
                            num_workers=config.num_workers,
                            shuffle=True)

    # create model
    net = Net(pretrained=True)
    net = net.to(device=config.device)
    net = nn.DataParallel(net)

    print(f'loading pretrained model from {config.view_net.ckpt_file}')
    checkpoint = torch.load(config.view_net.ckpt_file)
    net.module.load_state_dict(checkpoint['model'])

    with torch.no_grad():
        validate(val_loader, net)

    print('Test Finished!')
Example #4
def main():
    print('Training Process\nInitializing...\n')
    config.init_env()

    train_dataset = data_pth.multi_pc_data(config.multi_pc_net.data_root, status=STATUS_TRAIN)
    val_dataset = data_pth.multi_pc_data(config.multi_pc_net.data_root, status=STATUS_TEST)

    train_loader = DataLoader(train_dataset, batch_size=config.multi_pc_net.train.batch_sz,
                              num_workers=config.num_workers, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=config.multi_pc_net.validation.batch_sz,
                            num_workers=config.num_workers, shuffle=True, drop_last=True)

    best_prec1 = 0
    best_map = 0
    resume_epoch = 0
    # create model
    net = DGCNN_Multi_Cloud(n_neighbor=config.multi_pc_net.n_neighbor, num_classes=config.multi_pc_net.num_classes)
    net = torch.nn.DataParallel(net)
    net = net.to(device=config.device)
    optimizer = optim.Adam(net.parameters(), config.multi_pc_net.train.lr,
                          weight_decay=config.multi_pc_net.train.weight_decay)

    if config.multi_pc_net.train.resume:
        print(f'loading pretrained model from {config.multi_pc_net.ckpt_file}')
        checkpoint = torch.load(config.multi_pc_net.ckpt_file)
        net.module.load_state_dict({k[7:]: v for k, v in checkpoint['model'].items()})  # strip the 'module.' prefix saved by DataParallel
        optimizer.load_state_dict(checkpoint['optimizer'])
        best_prec1 = checkpoint['best_prec1']
        if config.multi_pc_net.train.resume_epoch is not None:
            resume_epoch = config.multi_pc_net.train.resume_epoch
        else:
            resume_epoch = checkpoint['epoch'] + 1

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 20, 0.7)
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device=config.device)

    for epoch in range(resume_epoch, config.multi_pc_net.train.max_epoch):

        lr_scheduler.step(epoch=epoch)
        # train
        train(train_loader, net, criterion, optimizer, epoch)
        # validation
        with torch.no_grad():
            prec1,retrieval_map = validate(val_loader, net, epoch)

        # save checkpoints
        if prec1 > best_prec1:
            best_prec1 = prec1
            save_ckpt(epoch, best_prec1, net, optimizer)
        if retrieval_map > best_map:
            best_map = retrieval_map


        # save_record(epoch, prec1, net.module)
        print('curr accuracy: ', prec1)
        print('best accuracy: ', best_prec1)
        print('best map: ', best_map)

    print('Train Finished!')
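
Example #4 strips the 'module.' prefix with k[7:], which silently corrupts every key if a checkpoint was saved without the nn.DataParallel wrapper. A safer sketch (the helper name is mine, not from the repo):

def strip_module_prefix(state_dict):
    # drop a leading 'module.' only where it is actually present, so the
    # same checkpoint loads into wrapped and unwrapped models alike
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state_dict.items()}

# usage: net.module.load_state_dict(strip_module_prefix(checkpoint['model']))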
Example #5
def main():
    config.init_env()
    if config.process == "TRAIN":
        train_process()
    elif config.process == "VAL":
        evaluation_process()
    elif config.process == "VIS":
        visualization_process()
    elif config.process == "TEST":
        extract_test_rst_process()
    else:
        raise NotImplementedError
Example #6
def main():
    print('Training Process\nInitializing...\n')
    config.init_env()

    val_dataset = pc_view_data(config.pv_net.pc_root,
                               config.pv_net.view_root,
                               status=STATUS_TEST,
                               base_model_name=config.base_model_name)
    val_loader = DataLoader(val_dataset,
                            batch_size=config.pv_net.train.batch_sz,
                            num_workers=config.num_workers,
                            shuffle=True)

    # create model
    net = PVNet2_v9()
    net = torch.nn.DataParallel(net)
    net = net.to(device=config.device)
    optimizer_all = optim.SGD(net.parameters(),
                              config.pv_net.train.all_lr,
                              momentum=config.pv_net.train.momentum,
                              weight_decay=config.pv_net.train.weight_decay)

    print(f'loading pretrained model from {config.pv_net.ckpt_file}')
    checkpoint = torch.load(config.pv_net.ckpt_file)
    state_dict = checkpoint['model']
    # net.module.load_state_dict({k[7:]: v for k, v in state_dict.items()})
    net.module.load_state_dict(state_dict)
    optimizer_all.load_state_dict(checkpoint['optimizer_all'])
    best_prec1 = checkpoint['best_prec1']
    resume_epoch = checkpoint['epoch']

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer_all, 5, 0.5)
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device=config.device)

    # for p in net.module.feature.parameters():
    #     p.requires_grad = False

    with torch.no_grad():
        prec1, Map = validate(val_loader, net, resume_epoch)

    print('curr accuracy: ', prec1)
    print('best accuracy: ', best_prec1)
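
torch.load as called above restores tensors to the device they were saved from, which fails when a GPU-trained checkpoint is evaluated on a machine without that device. Passing map_location is the standard remedy; a one-line sketch:

checkpoint = torch.load(config.pv_net.ckpt_file, map_location=config.device)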
Example #7
def main():
    print('Training Process\nInitializing...\n')
    config.init_env()

    train_dataset = pc_view_data(config.pv_net.pc_root,
                                 config.pv_net.view_root,
                                 status=STATUS_TRAIN,
                                 base_model_name=config.base_model_name)
    val_dataset = pc_view_data(config.pv_net.pc_root,
                               config.pv_net.view_root,
                               status=STATUS_TEST,
                               base_model_name=config.base_model_name)

    train_loader = DataLoader(train_dataset, batch_size=config.pv_net.train.batch_sz,
                              num_workers=config.num_workers, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=config.pv_net.train.batch_sz,
                            num_workers=config.num_workers, shuffle=True, drop_last=True)

    best_prec1 = 0
    resume_epoch = 0

    epoch_pc_view = 0
    epoch_pc = 0

    # create model
    net = PVNet2_lf()
    net = net.to(device=config.device)
    net = nn.DataParallel(net)

    # optimizer
    # fc_param = [{'params': v} for k, v in net.named_parameters() if 'mlp' in k or 'trans_ft' in k]
    fc_param = [{'params': v} for k, v in net.named_parameters() if 'fusion' in k]
    if config.pv_net.train.optim == 'Adam':
        optimizer_fc = optim.Adam(fc_param, config.pv_net.train.fc_lr,
                                  weight_decay=config.pv_net.train.weight_decay)

        optimizer_all = optim.Adam(net.parameters(), config.pv_net.train.all_lr,
                                   weight_decay=config.pv_net.train.weight_decay)
    elif config.pv_net.train.optim == 'SGD':
        optimizer_fc = optim.SGD(fc_param, config.pv_net.train.fc_lr,
                                 momentum=config.pv_net.train.momentum,
                                 weight_decay=config.pv_net.train.weight_decay)

        optimizer_all = optim.SGD(net.parameters(), config.pv_net.train.all_lr,
                                  momentum=config.pv_net.train.momentum,
                                  weight_decay=config.pv_net.train.weight_decay)
    else:
        raise NotImplementedError
    print(f'use {config.pv_net.train.optim} optimizer')

    #pc 0.001       1 epoch        e  down 0.6
    #all 0.0001     1.5 epoch      e  down 0.1

    if config.pv_net.train.resume:
        print(f'loading pretrained model from {config.pv_net.ckpt_file}')
        checkpoint = torch.load(config.pv_net.ckpt_file)
        net.module.load_state_dict(checkpoint['model'])
        optimizer_fc.load_state_dict(checkpoint['optimizer_pc'])
        optimizer_all.load_state_dict(checkpoint['optimizer_all'])
        best_prec1 = checkpoint['best_prec1']
        epoch_pc_view = checkpoint['epoch_pc_view']
        epoch_pc = checkpoint['epoch_pc']
        if config.pv_net.train.resume_epoch is not None:
            resume_epoch = config.pv_net.train.resume_epoch
        else:
            resume_epoch = checkpoint['epoch'] + 1

    lr_scheduler_fc = torch.optim.lr_scheduler.StepLR(optimizer_fc, 8, 0.3)
    lr_scheduler_all = torch.optim.lr_scheduler.StepLR(optimizer_all, 8, 0.3)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device=config.device)

    for epoch in range(resume_epoch, config.pv_net.train.max_epoch):

        # for p in net.module.parameters():
        #     p.requires_grad = True

        if epoch < 20:
            epoch_pc += 1
            lr_scheduler_fc.step(epoch=epoch_pc)
            train(train_loader, net, criterion, optimizer_fc, epoch)
        else:
            epoch_pc_view += 1
            lr_scheduler_all.step(epoch=epoch_pc_view)
            train(train_loader, net, criterion, optimizer_all, epoch)


        with torch.no_grad():
            prec1 = validate(val_loader, net, epoch)

        # save checkpoints
        if best_prec1 < prec1:
            best_prec1 = prec1
            save_ckpt(epoch, epoch_pc, epoch_pc_view, best_prec1, net, optimizer_fc, optimizer_all)

        # save_record(epoch, prec1, net.module)
        print('curr accuracy: ', prec1)
        print('best accuracy: ', best_prec1)

    print('Train Finished!')
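
Example #7 updates only the fusion head for the first 20 epochs simply because the backbone's parameters are absent from optimizer_fc; the backbone still pays the cost of gradient computation during that stage. To skip that cost as well, a small helper could toggle requires_grad around the stage switch. This is a sketch under the assumption that the fusion layers live in submodules whose names contain 'fusion', matching the fc_param filter above:

def set_requires_grad(module, flag):
    # enable or disable gradient computation for every parameter of a module
    for p in module.parameters():
        p.requires_grad = flag

# hypothetical usage around the epoch-20 switch:
# set_requires_grad(net.module, False)       # stage 1: freeze everything...
# for name, m in net.module.named_children():
#     if 'fusion' in name:
#         set_requires_grad(m, True)         # ...except the fusion head
# set_requires_grad(net.module, True)        # stage 2: unfreeze all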
Example #8
def main():
    print('Training Process\nInitializing...\n')
    config.init_env()
    args = parse_args()

    total_batch_sz = config.pv_net.train.batch_sz * len(
        config.available_gpus.split(','))
    total_epoch = config.pv_net.train.max_epoch

    if args.gpu is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
        total_batch_sz = config.pv_net.train.batch_sz * len(
            args.gpu.split(','))
    if args.epochs is not None:
        total_epoch = args.epochs

    train_dataset = pc_view_data(config.pv_net.pc_root,
                                 config.pv_net.view_root,
                                 status=STATUS_TRAIN,
                                 base_model_name=config.base_model_name)
    val_dataset = pc_view_data(config.pv_net.pc_root,
                               config.pv_net.view_root,
                               status=STATUS_TEST,
                               base_model_name=config.base_model_name)

    train_loader = DataLoader(train_dataset,
                              batch_size=total_batch_sz,
                              num_workers=config.num_workers,
                              shuffle=True,
                              drop_last=True)
    val_loader = DataLoader(val_dataset,
                            batch_size=total_batch_sz,
                            num_workers=config.num_workers,
                            shuffle=True)

    best_prec1 = 0
    best_map = 0
    resume_epoch = 0

    epoch_pc_view = 0
    epoch_pc = 0

    # create model
    net = PVRNet()
    net = net.to(device=config.device)
    net = nn.DataParallel(net)

    # optimizer
    fc_param = [{
        'params': v
    } for k, v in net.named_parameters() if 'fusion' in k]
    if config.pv_net.train.optim == 'Adam':
        optimizer_fc = optim.Adam(
            fc_param,
            config.pv_net.train.fc_lr,
            weight_decay=config.pv_net.train.weight_decay)

        optimizer_all = optim.Adam(
            net.parameters(),
            config.pv_net.train.all_lr,
            weight_decay=config.pv_net.train.weight_decay)
    elif config.pv_net.train.optim == 'SGD':
        optimizer_fc = optim.SGD(fc_param,
                                 config.pv_net.train.fc_lr,
                                 momentum=config.pv_net.train.momentum,
                                 weight_decay=config.pv_net.train.weight_decay)

        optimizer_all = optim.SGD(
            net.parameters(),
            config.pv_net.train.all_lr,
            momentum=config.pv_net.train.momentum,
            weight_decay=config.pv_net.train.weight_decay)
    else:
        raise NotImplementedError
    print(f'use {config.pv_net.train.optim} optimizer')
    print(f'Scale: {net.module.n_scale}')

    if config.pv_net.train.resume:
        print(f'loading pretrained model from {config.pv_net.ckpt_file}')
        checkpoint = torch.load(config.pv_net.ckpt_file)
        state_dict = checkpoint['model']
        net.module.load_state_dict(state_dict)
        optimizer_fc.load_state_dict(checkpoint['optimizer_pc'])
        optimizer_all.load_state_dict(checkpoint['optimizer_all'])
        best_prec1 = checkpoint['best_prec1']
        epoch_pc_view = checkpoint['epoch_all']
        epoch_pc = checkpoint['epoch_pc']
        if config.pv_net.train.resume_epoch is not None:
            resume_epoch = config.pv_net.train.resume_epoch
        else:
            resume_epoch = max(checkpoint['epoch_pc'], checkpoint['epoch_all'])

    if not config.pv_net.train.iter_train:
        print('No iter')
        lr_scheduler_fc = torch.optim.lr_scheduler.StepLR(optimizer_fc, 5, 0.3)
        lr_scheduler_all = torch.optim.lr_scheduler.StepLR(
            optimizer_all, 5, 0.3)
    else:
        print('iter')
        lr_scheduler_fc = torch.optim.lr_scheduler.StepLR(optimizer_fc, 6, 0.3)
        lr_scheduler_all = torch.optim.lr_scheduler.StepLR(
            optimizer_all, 6, 0.3)

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device=config.device)

    for epoch in range(resume_epoch, total_epoch):

        if config.pv_net.train.iter_train:
            if epoch < 12:
                lr_scheduler_fc.step(epoch=epoch_pc)
                print(lr_scheduler_fc.get_lr())

                if (epoch_pc + 1) % 3 == 0:
                    print('train score block')
                    # freeze everything, then re-enable only the fusion layer
                    for p in net.module.parameters():
                        p.requires_grad = False
                    for p in net.module.fusion_conv1.parameters():
                        p.requires_grad = True
                else:
                    print('train all fc block')
                    for p in net.module.parameters():
                        p.requires_grad = True

                train(train_loader, net, criterion, optimizer_fc, epoch)
                epoch_pc += 1

            else:
                lr_scheduler_all.step(epoch=epoch_pc_view)
                print(lr_scheduler_all.get_lr())

                if (epoch_pc_view + 1) % 3 == 0:
                    print('train score block')
                    # freeze everything, then re-enable only the fusion layer
                    for p in net.module.parameters():
                        p.requires_grad = False
                    for p in net.module.fusion_conv1.parameters():
                        p.requires_grad = True
                else:
                    print('train all block')
                    for p in net.module.parameters():
                        p.requires_grad = True

                train(train_loader, net, criterion, optimizer_all, epoch)
                epoch_pc_view += 1

        else:
            if epoch < 10:
                lr_scheduler_fc.step(epoch=epoch_pc)
                print(lr_scheduler_fc.get_lr())
                train(train_loader, net, criterion, optimizer_fc, epoch)
                epoch_pc += 1

            else:
                lr_scheduler_all.step(epoch=epoch_pc_view)
                print(lr_scheduler_all.get_lr())
                train(train_loader, net, criterion, optimizer_all, epoch)
                epoch_pc_view += 1

        with torch.no_grad():
            prec1, retrieval_map = validate(val_loader, net, epoch)

        # save checkpoints
        if best_prec1 < prec1:
            best_prec1 = prec1
            save_ckpt(epoch, epoch_pc, epoch_pc_view, best_prec1, net,
                      optimizer_fc, optimizer_all)
        if best_map < retrieval_map:
            best_map = retrieval_map

        print('curr accuracy: ', prec1)
        print('best accuracy: ', best_prec1)
        print('best map: ', best_map)

    print('Train Finished!')
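
parse_args in Examples #8 and #10 is not shown in this listing. Judging only by the attributes the code reads (args.gpu, args.epochs), a minimal sketch could look like the following; the flag names are assumptions:

import argparse

def parse_args():
    # only the two attributes the examples actually read are defined here
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=str, default=None,
                        help="comma-separated GPU ids, e.g. '0,1'")
    parser.add_argument('--epochs', type=int, default=None,
                        help='override the configured max_epoch')
    return parser.parse_args()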
Example #9
def main():
    print('Training Process\nInitializing...\n')
    config.init_env()

    train_dataset = data_pth.view_data(config.view_net.data_root,
                                       status=STATUS_TRAIN,
                                       base_model_name=config.base_model_name)
    val_dataset = data_pth.view_data(config.view_net.data_root,
                                     status=STATUS_TEST,
                                     base_model_name=config.base_model_name)

    train_loader = DataLoader(train_dataset, batch_size=config.view_net.train.batch_sz,
                              num_workers=config.num_workers, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.view_net.train.batch_sz,
                            num_workers=config.num_workers, shuffle=True)

    best_prec1 = 0
    resume_epoch = 0
    # create model
    net = MVCNN()
    net = net.to(device=config.device)
    net = nn.DataParallel(net)
    optimizer = optim.SGD(net.parameters(), config.view_net.train.lr,
                          momentum=config.view_net.train.momentum,
                          weight_decay=config.view_net.train.weight_decay)
    # optimizer = optim.Adam(net.parameters(), config.view_net.train.lr,
    #                        weight_decay=config.view_net.train.weight_decay)

    if config.view_net.train.resume:
        print(f'loading pretrained model from {config.view_net.ckpt_file}')
        checkpoint = torch.load(config.view_net.ckpt_file)
        net.module.load_state_dict({k[7:]: v for k, v in checkpoint['model'].items()})  # strip the 'module.' prefix saved by DataParallel
        # net.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        best_prec1 = checkpoint['best_prec1']
        if config.view_net.train.resume_epoch is not None:
            resume_epoch = config.view_net.train.resume_epoch
        else:
            resume_epoch = checkpoint['epoch'] + 1

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 5, 0.5)
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.to(device=config.device)

    # for p in net.module.feature.parameters():
    #     p.requires_grad = False

    for epoch in range(resume_epoch, config.view_net.train.max_epoch):
        if epoch >= 5:
            for p in net.parameters():
                p.requires_grad = True
        lr_scheduler.step(epoch=epoch)

        train(train_loader, net, criterion, optimizer, epoch)

        with torch.no_grad():
            prec1 = validate(val_loader, net, epoch)

        # save checkpoints
        if best_prec1 < prec1:
            best_prec1 = prec1
            save_ckpt(epoch, best_prec1, net, optimizer)

        save_record(epoch, prec1, net.module)
        print('curr accuracy: ', prec1)
        print('best accuracy: ', best_prec1)

    print('Train Finished!')
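
In Example #9, the requires_grad = True loop that runs from epoch 5 onward only matters if something was frozen beforehand, and the freezing code is commented out. A minimal sketch of the presumably intended warm-up, reusing the 'feature' submodule name from that comment:

# freeze the pretrained feature extractor before the training loop
for p in net.module.feature.parameters():
    p.requires_grad = False

for epoch in range(resume_epoch, config.view_net.train.max_epoch):
    if epoch == 5:  # unfreeze once instead of re-setting every epoch
        for p in net.parameters():
            p.requires_grad = True
    # ... train/validate/save exactly as in Example #9 ...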
Example #10
def main():
    print('Training Process\nInitializing...\n')
    config.init_env()
    args = parse_args()

    total_batch_sz = config.pvd_net.train.batch_sz * len(config.available_gpus.split(','))
    total_epoch = config.pvd_net.train.max_epoch

    if args.gpu is not None:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
        total_batch_sz = config.pvd_net.train.batch_sz * len(args.gpu.split(','))
    if args.epochs is not None:
        total_epoch = args.epochs

    train_dataset = pc_view_dp_data(config.pvd_net.pc_root,
                                    config.pvd_net.view_root,
                                    config.pvd_net.depth_root,
                                    status=STATUS_TRAIN,
                                    base_model_name=config.base_model_name)
    val_dataset = pc_view_dp_data(config.pvd_net.pc_root,
                                  config.pvd_net.view_root,
                                  config.pvd_net.depth_root,
                                  status=STATUS_TEST,
                                  base_model_name=config.base_model_name)

    train_loader = DataLoader(train_dataset, batch_size=total_batch_sz,
                              num_workers=config.num_workers, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=total_batch_sz,
                            num_workers=config.num_workers, shuffle=True)

    best_prec1 = 0
    best_map = 0
    resume_epoch = 0

    epoch_pc_view_dp = 0
    epoch_pc = 0

    # create model
    net = PVRNet()
    net = net.to(device=config.device)
    net = nn.DataParallel(net)

    # optimizer
    g_params = [{'params': v} for k, v in net.named_parameters() if 'GEN' in k]
    en_params = [{'params': v} for k, v in net.named_parameters() if 'GEN_EN' in k or 'CLS' in k]
    dis_params = [{'params': v} for k, v in net.named_parameters() if 'DIS' in k]
    f_params = [{'params': v} for k, v in net.named_parameters()
                if not any(s in k for s in ('GEN', 'DIS', 'CLS'))]
    if config.pvd_net.train.optim == 'Adam':
        b1 = 0.5
        b2 = 0.999

        optimizer_G = torch.optim.Adam(g_params, weight_decay=config.pvd_net.train.weight_decay,
                                       lr=config.pvd_net.train.all_lr, betas=(b1, b2))
        optimizer_D = torch.optim.Adam(dis_params, lr=config.pvd_net.train.all_lr, betas=(b1, b2),
                                       weight_decay=config.pvd_net.train.weight_decay)
        optimizer_E = torch.optim.Adam(en_params, lr=config.pvd_net.train.all_lr, betas=(b1, b2),
                                       weight_decay=config.pvd_net.train.weight_decay)
        optimizer_F = torch.optim.Adam(f_params, lr=config.pvd_net.train.all_lr, betas=(b1, b2),
                                       weight_decay=config.pvd_net.train.weight_decay)
        optimizer = [optimizer_G, optimizer_D, optimizer_E, optimizer_F]
    else:
        raise NotImplementedError
    print(f'use {config.pvd_net.train.optim} optimizer')
    # print(f'Scale: {net.module.n_scale}')

    if config.pvd_net.train.resume:
        print(f'loading pretrained model from {config.pvd_net.ckpt_file}')
        checkpoint = torch.load(config.pvd_net.ckpt_file)
        state_dict = checkpoint['model']
        net.module.load_state_dict(checkpoint['model'])
        # optimizer_fc.load_state_dict(checkpoint['optimizer_pc'])
        # optimizer_all.load_state_dict(checkpoint['optimizer_all'])
        best_prec1 = checkpoint['best_prec1']
        epoch_pc_view_dp = checkpoint['epoch_all']
        epoch_pc = checkpoint['epoch_pc']
        if config.pvd_net.train.resume_epoch is not None:
            resume_epoch = config.pvd_net.train.resume_epoch
        else:
            resume_epoch = max(checkpoint['epoch_pc'], checkpoint['epoch_all'])

    if not config.pvd_net.train.iter_train:  # iter_train: update parameters with iterative gradient-descent steps during training
        print('No iter')
        lr_scheduler_g = torch.optim.lr_scheduler.StepLR(optimizer_G, 5, 0.3)
        lr_scheduler_d = torch.optim.lr_scheduler.StepLR(optimizer_D, 5, 0.3)
        lr_scheduler_e = torch.optim.lr_scheduler.StepLR(optimizer_E, 5, 0.3)
        lr_scheduler_f = torch.optim.lr_scheduler.StepLR(optimizer_F, 5, 0.3)
    else:
        print('VCIter')
        lr_scheduler_g = torch.optim.lr_scheduler.StepLR(optimizer_G, 6, 0.3)
        lr_scheduler_d = torch.optim.lr_scheduler.StepLR(optimizer_D, 6, 0.3)
        lr_scheduler_e = torch.optim.lr_scheduler.StepLR(optimizer_E, 6, 0.3)
        lr_scheduler_f = torch.optim.lr_scheduler.StepLR(optimizer_F, 5, 0.3)
    lr_scheduler = [lr_scheduler_g, lr_scheduler_d, lr_scheduler_e, lr_scheduler_f]

    for epoch in range(resume_epoch, total_epoch):
        criterion_adv = torch.nn.BCELoss().to(device=config.device)
        criterion_pix = torch.nn.L1Loss(reduction='sum').to(device=config.device)
        criterion_cls = nn.L1Loss().to(device=config.device)
        criterion_fusion = nn.CrossEntropyLoss().to(device=config.device)
        criterion = [criterion_adv, criterion_pix, criterion_cls, criterion_fusion]

        # both settings of iter_train currently run the same training step;
        # the per-optimizer schedulers are passed through to train()
        train(train_loader, net, criterion, optimizer, lr_scheduler, epoch)
        epoch_pc_view_dp += 1

        with torch.no_grad():
            prec1, retrieval_map = validate(val_loader, net, epoch)

        # save checkpoints
        if best_prec1 < prec1:
            best_prec1 = prec1
            save_ckpt(epoch, epoch_pc, epoch_pc_view_dp, best_prec1, net, optimizer)
        if best_map < retrieval_map:
            best_map = retrieval_map

        print('curr accuracy: ', prec1)
        print('best accuracy: ', best_prec1)
        print('best map: ', best_map)

    print('Train Finished!')
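
The parameter grouping in Example #10 originally tested ('GEN_EN' or 'CLS') in k. Python evaluates the parenthesized or first, so that condition is just 'GEN_EN' in k and the CLS parameters were silently dropped; likewise not ('GEN' or 'DIS' or 'CLS') in k reduced to 'GEN' not in k. A quick illustration of the pitfall and the any() form used above:

names = ['GEN_EN.conv.weight', 'CLS.fc.bias', 'DIS.conv.weight']

# buggy: ('GEN_EN' or 'CLS') evaluates to 'GEN_EN', so CLS never matches
print([n for n in names if ('GEN_EN' or 'CLS') in n])
# -> ['GEN_EN.conv.weight']

# correct: test each substring explicitly
print([n for n in names if any(s in n for s in ('GEN_EN', 'CLS'))])
# -> ['GEN_EN.conv.weight', 'CLS.fc.bias']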