Example #1
def main():
    global cfg, global_all_step
    global_all_step = 0
    cfg = DefaultConfig()
    args = {
        'resnet18': RESNET18_SUNRGBD_CONFIG().args(),
    }

    # Setting random seed
    # if cfg.MANUAL_SEED is None:
    #     cfg.MANUAL_SEED = random.randint(1, 10000)
    # random.seed(cfg.MANUAL_SEED)
    # torch.manual_seed(cfg.MANUAL_SEED)

    # args for different backbones
    cfg.parse(args['resnet18'])
    run_id = random.randint(1, 100000)
    summary_dir = '/home/lzy/summary/generateDepth/' + 'train_depth_1e-3_0909_' + str(
        run_id)
    if not os.path.exists(summary_dir):
        os.mkdir(summary_dir)
    writer = SummaryWriter(summary_dir)
    cfg.LR = 0.001
    os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2'
    device_ids = torch.cuda.device_count()
    print('device_ids:', device_ids)
    # project_name = reduce(lambda x, y: str(x) + '/' + str(y), os.path.realpath(__file__).split(os.sep)[:-1])
    # util.mkdir('logs')
    # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                                 std=[0.229, 0.224, 0.225])
    # normalize=transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    # data
    train_dataset = dataset.SingleDataset(
        cfg,
        data_dir=cfg.DATA_DIR_TRAIN,
        transform=transforms.Compose([
            dataset.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE)),
            dataset.RandomCrop((cfg.FINE_SIZE, cfg.FINE_SIZE)),
            dataset.RandomHorizontalFlip(),
            # dataset.RandomRotate(),
            # dataset.RandomFlip(),
            # dataset.PILBrightness(0.4),
            # dataset.PILContrast(0.4),
            # dataset.PILColorBalance(0.4),
            # dataset.PILSharpness(0.4),
            dataset.ToTensor(),
            dataset.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
        ]))

    val_dataset = dataset.SingleDataset(
        cfg,
        data_dir=cfg.DATA_DIR_VAL,
        transform=transforms.Compose([
            dataset.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE)),
            dataset.CenterCrop((cfg.FINE_SIZE, cfg.FINE_SIZE)),
            dataset.ToTensor(),
            dataset.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
        ]))
    # train_loader = torch.utils.data.DataLoader(
    #     train_dataset, batch_size=cfg.BATCH_SIZE, shuffle=True,
    #     num_workers=4, pin_memory=True, sampler=None)
    # val_loader = torch.utils.data.DataLoader(
    #     val_dataset,batch_size=16, shuffle=False,
    #     num_workers=4, pin_memory=True)
    train_loader = DataProvider(cfg,
                                dataset=train_dataset,
                                batch_size=150,
                                shuffle=True)
    val_loader = DataProvider(cfg,
                              dataset=val_dataset,
                              batch_size=120,
                              shuffle=False)

    # class weights
    # num_classes_train = list(Counter([i[1] for i in train_loader.dataset.imgs]).values())
    # cfg.CLASS_WEIGHTS_TRAIN = torch.FloatTensor(num_classes_train)

    # writer = SummaryWriter(log_dir=cfg.LOG_PATH)  # tensorboard

    # net_classification_1=resnet50(pretrained=True)
    # net_classification_2=resnet50(pretrained=True)
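    # Depth branch: a ResNet-50 built with a 365-way head (Places365-style). Its fc is
    # first swapped to 19 outputs so the SUN RGB-D depth checkpoint below loads with
    # matching shapes, then replaced again with a 67-way head after loading.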

    net_classification_2 = models.__dict__['resnet50'](num_classes=365)
    # net_classification_2=torchvision.models.resnext101_32x8d(pretrained=True, progress=True)
    # net_classification_2 = models.__dict__['resnet18'](num_classes=365)
    # net_classification_2=torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x16d_wsl')
    net_classification_2.fc = nn.Linear(2048, 19)

    # net_classification_2=torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x16d_wsl')
    # for param in net_classification_1.parameters():
    #     param.requires_grad = False
    # for param in net_classification_2.parameters():
    #     param.requires_grad = True
    # net_classification_1.fc = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(2048, 1024),nn.LeakyReLU(inplace=True),nn.Linear(1024,67))
    # net_classification_2.fc = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(2048,1024),nn.LeakyReLU(inplace=True),nn.Linear(1024,67))

    # net_classification_1.load_state_dict(torch.load("./bestmodel/best_model_resnext_16d_2048_1024_dropout_0.5_b.pkl"))
    # net_classification_2.load_state_dict(torch.load("./bestmodel/best_model_resnext_16d_2048_1024_dropout_0.5_b.pkl"))
    # net_classification_2
    load_path = "resnet50_Depth_sunrgbd_best_0909_5e-4_.pth.tar"
    checkpoint = torch.load(load_path,
                            map_location=lambda storage, loc: storage)
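    # Strip the 'module.' prefix that nn.DataParallel adds to parameter names so the
    # keys match the single-GPU model they are loaded into.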
    state_dict = {
        str.replace(k, 'module.', ''): v
        for k, v in checkpoint['state_dict'].items()
    }
    best_mean_depth = checkpoint['best_mean_2']
    print("load sunrgbd dataset:", best_mean_depth)
    net_classification_2.load_state_dict(state_dict)
    # print(net_classification_1)
    # num_ftrs = net_classification_1.fc.in_features
    # net_classification_1.fc = nn.Linear(num_ftrs, cfg.NUM_CLASSES)
    # num_ftrs = net_classification_2.fc.in_features
    # net_classification_2.fc = nn.Linear(num_ftrs, cfg.NUM_CLASSES)
    # net_classification_2.fc = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(2048,1024),nn.LeakyReLU(inplace=True),nn.Linear(1024,67))
    net_classification_2.fc = nn.Linear(2048, 67)
    init_weights(net_classification_2.fc, 'normal')
    print(net_classification_2)

    # net_classification_1.cuda()
    net_classification_2.cuda()
    cudnn.benchmark = True
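    # Optionally build the RGB-to-depth translation network (TrecgNet) used to
    # synthesize depth inputs on the fly. Note that train()/validate() below pass
    # generate_model unconditionally, so this flag is assumed to be enabled.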

    if cfg.GENERATE_Depth_DATA:
        print('GENERATE_Depth_DATA model set')
        cfg_generate = copy.deepcopy(cfg)
        cfg_generate.CHECKPOINTS_DIR = '/home/lzy/generateDepth/bestmodel/before_best.pth'
        cfg_generate.GENERATE_Depth_DATA = False
        cfg_generate.NO_UPSAMPLE = False
        checkpoint = torch.load(cfg_generate.CHECKPOINTS_DIR)
        model = define_TrecgNet(cfg_generate, upsample=True, generate=True)
        load_checkpoint_depth(model,
                              cfg_generate.CHECKPOINTS_DIR,
                              checkpoint,
                              data_para=True)
        generate_model = torch.nn.DataParallel(model).cuda()
        generate_model.eval()

    # net_classification_1 = torch.nn.DataParallel(net_classification_1).cuda()
    net_classification_2 = torch.nn.DataParallel(net_classification_2).cuda()
    criterion = nn.CrossEntropyLoss().cuda()

    # best_mean_1=0
    best_mean_2 = 0
    # optimizer_2 = torch.optim.Adam(net_classification_2.parameters(), lr=cfg.LR, betas=(0.5, 0.999))
    optimizer_2 = torch.optim.SGD(net_classification_2.parameters(),
                                  lr=cfg.LR,
                                  momentum=cfg.MOMENTUM,
                                  weight_decay=cfg.WEIGHT_DECAY)

    schedulers = get_scheduler(optimizer_2, cfg, cfg.LR_POLICY)
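    # The 1000-epoch range is only a safety cap; training stops once the global step
    # counter (presumably updated inside train()) exceeds cfg.NITER_TOTAL.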
    for epoch in range(0, 1000):
        if global_all_step > cfg.NITER_TOTAL:
            break
        # meanacc_1,meanacc_2=validate(val_loader, net_classification_1,net_classification_2,generate_model,criterion,epoch)
        # for param_group in optimizer_2.param_groups:
        #     lr_t = param_group['lr']
        #     print('/////////learning rate = %.7f' % lr_t)
        #     writer.add_scalar('LR',lr_t,global_step=epoch)
        printlr(optimizer_2, writer, epoch)

        train(train_loader, schedulers, net_classification_2, generate_model,
              criterion, optimizer_2, epoch, writer)
        meanacc_2 = validate(val_loader, net_classification_2, generate_model,
                             criterion, epoch, writer)
        # meanacc_2=validate(val_loader,net_classification_2,generate_model,criterion,epoch,writer)

        # train(train_loader,net_classification_2,generate_model,criterion,optimizer_2,epoch,writer)
        # meanacc_2=validate(val_loader,net_classification_2,generate_model,criterion,epoch,writer)

        # writer.add_image(depth_image[0])
        # save best
        # if meanacc_1>best_mean_1:
        #     best_mean_1=meanacc_1
        #     print('best_mean_color:',str(best_mean_1))
        #     save_checkpoint({
        #         'epoch': epoch,
        #         'arch': cfg.ARCH,
        #         'state_dict': net_classification_1.state_dict(),
        #         'best_mean_1': best_mean_1,
        #         'optimizer' : optimizer_1.state_dict(),
        #     },CorD=True)

        if meanacc_2 > best_mean_2:
            best_mean_2 = meanacc_2
            print('best_mean_depth:', str(best_mean_2))
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': cfg.ARCH,
                    'state_dict': net_classification_2.state_dict(),
                    'best_mean_2': best_mean_2,
                    'optimizer': optimizer_2.state_dict(),
                },
                CorD=False)
        # print('best_mean_color:',str(best_mean_1))
        # writer.add_scalar('mean_acc_color', meanacc_1, global_step=epoch)
        writer.add_scalar('mean_acc_depth', meanacc_2, global_step=epoch)
        # writer.add_scalar('best_meanacc_color', best_mean_1, global_step=epoch)
        writer.add_scalar('best_meanacc_depth', best_mean_2, global_step=epoch)

    writer.close()
Example #2
def main():
    global cfg
    cfg = DefaultConfig()
    args = {
        'resnet18': RESNET18_SUNRGBD_CONFIG().args(),
    }

    # Setting random seed
    # if cfg.MANUAL_SEED is None:
    #     cfg.MANUAL_SEED = random.randint(1, 10000)
    # random.seed(cfg.MANUAL_SEED)
    # torch.manual_seed(cfg.MANUAL_SEED)

    # args for different backbones
    cfg.parse(args['resnet18'])
    run_id = random.randint(1, 100000)
    summary_dir = '/home/lzy/summary/generateDepth/' + 'train_rgb_' + str(
        run_id)
    if not os.path.exists(summary_dir):
        os.mkdir(summary_dir)
    writer = SummaryWriter(summary_dir)
    cfg.LR = 0.0001
    os.environ["CUDA_VISIBLE_DEVICES"] = '1,2'
    device_ids = torch.cuda.device_count()
    print('device_ids:', device_ids)
    # project_name = reduce(lambda x, y: str(x) + '/' + str(y), os.path.realpath(__file__).split(os.sep)[:-1])
    # util.mkdir('logs')
    # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                                 std=[0.229, 0.224, 0.225])
    # normalize=transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    # data
    train_dataset = dataset.SingleDataset(
        cfg,
        data_dir=cfg.DATA_DIR_TRAIN,
        transform=transforms.Compose([
            dataset.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE)),
            dataset.RandomCrop((cfg.FINE_SIZE, cfg.FINE_SIZE)),
            dataset.RandomHorizontalFlip(),
            dataset.ToTensor(),
            dataset.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
        ]))

    val_dataset = dataset.SingleDataset(
        cfg,
        data_dir=cfg.DATA_DIR_VAL,
        transform=transforms.Compose([
            dataset.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE)),
            dataset.CenterCrop((cfg.FINE_SIZE, cfg.FINE_SIZE)),
            dataset.ToTensor(),
            dataset.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
        ]))
    # train_loader = torch.utils.data.DataLoader(
    #     train_dataset, batch_size=cfg.BATCH_SIZE, shuffle=True,
    #     num_workers=4, pin_memory=True, sampler=None)
    # val_loader = torch.utils.data.DataLoader(
    #     val_dataset,batch_size=16, shuffle=False,
    #     num_workers=4, pin_memory=True)
    train_loader = DataProvider(cfg,
                                dataset=train_dataset,
                                batch_size=40,
                                shuffle=True)
    val_loader = DataProvider(cfg,
                              dataset=val_dataset,
                              batch_size=10,
                              shuffle=False)

    # class weights
    # num_classes_train = list(Counter([i[1] for i in train_loader.dataset.imgs]).values())
    # cfg.CLASS_WEIGHTS_TRAIN = torch.FloatTensor(num_classes_train)

    # writer = SummaryWriter(log_dir=cfg.LOG_PATH)  # tensorboard

    # net_classification_1=resnet50(pretrained=True)
    # net_classification_2=resnet50(pretrained=True)

    # net_classification_1 = models.__dict__['resnet18'](num_classes=365)
    # net_classification_2 = models.__dict__['resnet18'](num_classes=365)
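    # RGB branch: ResNeXt-101 32x16d pre-trained on weakly supervised Instagram data
    # (the WSL-Images release), with a dropout + two-layer head for 67 scene classes.
    # The weights loaded below are presumably an earlier fine-tuned snapshot saved with
    # this same head layout.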
    net_classification_1 = torch.hub.load('facebookresearch/WSL-Images',
                                          'resnext101_32x16d_wsl')
    # net_classification_2=torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x16d_wsl')
    # for param in net_classification_1.parameters():
    #     param.requires_grad = False
    # for param in net_classification_2.parameters():
    #     param.requires_grad = True
    net_classification_1.fc = nn.Sequential(nn.Dropout(p=0.5),
                                            nn.Linear(2048, 1024),
                                            nn.LeakyReLU(inplace=True),
                                            nn.Linear(1024, 67))
    # net_classification_2.fc = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(2048, 1024),nn.LeakyReLU(inplace=True),nn.Linear(1024,67))

    net_classification_1.load_state_dict(
        torch.load(
            "./bestmodel/best_model_resnext_16d_2048_1024_dropout_0.5_b.pkl"))
    # net_classification_2.load_state_dict(torch.load("./bestmodel/best_model_resnext_16d_2048_1024_dropout_0.5_b.pkl"))
    # net_classification_2
    # load_path = "/home/dudapeng/workspace/pretrained/place/resnet18_places365.pth"
    # checkpoint = torch.load(load_path, map_location=lambda storage, loc: storage)
    # state_dict = {str.replace(k, 'module.', ''): v for k, v in checkpoint['state_dict'].items()}
    # net_classification_1.load_state_dict(state_dict)
    # net_classification_2.load_state_dict(state_dict)
    print(net_classification_1)

    # num_ftrs = net_classification_1.fc.in_features
    # net_classification_1.fc = nn.Linear(num_ftrs, cfg.NUM_CLASSES)
    # num_ftrs = net_classification_2.fc.in_features
    # net_classification_2.fc = nn.Linear(num_ftrs, cfg.NUM_CLASSES)

    net_classification_1.cuda()
    # net_classification_2.cuda()
    cudnn.benchmark = True

    # if cfg.GENERATE_Depth_DATA:
    #     print('GENERATE_Depth_DATA model set')
    #     cfg_generate = copy.deepcopy(cfg)
    #     cfg_generate.CHECKPOINTS_DIR='/home/lzy/generateDepth/checkpoints/best_AtoB/trecg_AtoB_best.pth'
    #     cfg_generate.GENERATE_Depth_DATA = False
    #     cfg_generate.NO_UPSAMPLE = False
    #     checkpoint = torch.load(cfg_generate.CHECKPOINTS_DIR)
    #     model = define_TrecgNet(cfg_generate, upsample=True,generate=True)
    #     load_checkpoint_depth(model,cfg_generate.CHECKPOINTS_DIR, checkpoint, data_para=True)
    #     generate_model = torch.nn.DataParallel(model).cuda()
    #     generate_model.eval()

    net_classification_1 = torch.nn.DataParallel(net_classification_1).cuda()
    # net_classification_2 = torch.nn.DataParallel(net_classification_2).cuda()
    criterion = nn.CrossEntropyLoss().cuda()

    best_mean_1 = 0
    # best_mean_2=0
    # optimizer = optim.SGD(model_ft.parameters(), lr=0.05,momentum=0.9)#,weight_decay=0.00005)
    optimizer_1 = torch.optim.SGD(net_classification_1.parameters(),
                                  lr=cfg.LR,
                                  momentum=cfg.MOMENTUM,
                                  weight_decay=cfg.WEIGHT_DECAY)
    # optimizer_2 = torch.optim.SGD(net_classification_2.parameters(),lr=cfg.LR,momentum=cfg.MOMENTUM,weight_decay=cfg.WEIGHT_DECAY)
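    # Unlike the depth script above, this run steps the learning rate once per epoch
    # via adjust_learning_rate() rather than through a scheduler object.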
    for epoch in range(0, 100):
        adjust_learning_rate(optimizer_1, epoch)
        # meanacc_1,meanacc_2=validate(val_loader, net_classification_1,net_classification_2,generate_model,criterion,epoch)
        #
        train(train_loader, net_classification_1, criterion, optimizer_1,
              epoch, writer)
        meanacc_1 = validate(val_loader, net_classification_1, criterion,
                             epoch, writer)
        # meanacc_2=validate(val_loader,net_classification_2,generate_model,criterion,epoch,writer)

        # train(train_loader,net_classification_2,generate_model,criterion,optimizer_2,epoch,writer)
        # meanacc_2=validate(val_loader,net_classification_2,generate_model,criterion,epoch,writer)

        # writer.add_image(depth_image[0])
        # save best
        if meanacc_1 > best_mean_1:
            best_mean_1 = meanacc_1
            print('best_mean_color:', str(best_mean_1))
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': cfg.ARCH,
                    'state_dict': net_classification_1.state_dict(),
                    'best_mean_1': best_mean_1,
                    'optimizer': optimizer_1.state_dict(),
                },
                CorD=True)

        # if meanacc_2>best_mean_2:
        #     best_mean_2=meanacc_2
        #     print('best_mean_depth:',str(best_mean_2))
        #     save_checkpoint({
        #         'epoch': epoch,
        #         'arch': cfg.ARCH,
        #         'state_dict': net_classification_2.state_dict(),
        #         'best_mean_2': best_mean_2,
        #         'optimizer' : optimizer_2.state_dict(),
        #     },CorD=False)
        print('best_mean_color:', str(best_mean_1))
        writer.add_scalar('mean_acc_color', meanacc_1, global_step=epoch)
        # writer.add_scalar('mean_acc_depth', meanacc_2, global_step=epoch)
        writer.add_scalar('best_meanacc_color', best_mean_1, global_step=epoch)
        # writer.add_scalar('best_meanacc_depth', best_mean_2, global_step=epoch)

    writer.close()
Example #3
def main():
    global cfg
    cfg = DefaultConfig()
    args = {
        'resnet18': RESNET18_SUNRGBD_CONFIG().args(),
    }

    # args for different backbones
    cfg.parse(args['resnet18'])
    cfg.LR = 1e-4
    cfg.EPOCHS = 200
    # print('cfg.EPOCHS:',cfg.EPOCHS)
    os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'

    # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                                 std=[0.229, 0.224, 0.225])
    # dataset.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    # data
    train_dataset = dataset.SingleDataset(
        cfg,
        data_dir=cfg.DATA_DIR_TRAIN,
        transform=transforms.Compose([
            dataset.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE)),
            dataset.RandomCrop((cfg.FINE_SIZE, cfg.FINE_SIZE)),
            dataset.RandomHorizontalFlip(),
            dataset.ToTensor(),
            dataset.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
        ]))

    val_dataset = dataset.SingleDataset(
        cfg,
        data_dir=cfg.DATA_DIR_VAL,
        transform=transforms.Compose([
            dataset.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE)),
            dataset.CenterCrop((cfg.FINE_SIZE, cfg.FINE_SIZE)),
            dataset.ToTensor(),
            dataset.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
        ]))
    # train_loader = torch.utils.data.DataLoader(
    #     train_dataset, batch_size=cfg.BATCH_SIZE, shuffle=True,
    #     num_workers=4, pin_memory=True, sampler=None)
    # val_loader = torch.utils.data.DataLoader(
    #     val_dataset,batch_size=cfg.BATCH_SIZE, shuffle=False,
    #     num_workers=4, pin_memory=True)
    train_loader = DataProvider(cfg,
                                dataset=train_dataset,
                                batch_size=20,
                                shuffle=True)
    val_loader = DataProvider(cfg,
                              dataset=val_dataset,
                              batch_size=5,
                              shuffle=False)

    run_id = random.randint(1, 100000)
    summary_dir = '/home/lzy/summary/generateDepth/' + 'finetuning_nofix_' + str(
        run_id)
    if not os.path.exists(summary_dir):
        os.mkdir(summary_dir)
    writer = SummaryWriter(summary_dir)

    if cfg.GENERATE_Depth_DATA:
        print('GENERATE_Depth_DATA model set')
        cfg_generate = copy.deepcopy(cfg)

        cfg_generate.CHECKPOINTS_DIR = '/home/lzy/generateDepth/checkpoints/best_AtoB/trecg_AtoB_best.pth'
        cfg_generate.GENERATE_Depth_DATA = False
        cfg_generate.NO_UPSAMPLE = False
        checkpoint = torch.load(cfg_generate.CHECKPOINTS_DIR)
        model = define_TrecgNet(cfg_generate, upsample=True, generate=True)
        load_checkpoint_depth(model,
                              cfg_generate.CHECKPOINTS_DIR,
                              checkpoint,
                              data_para=True)
        generate_model = torch.nn.DataParallel(model).cuda()
        generate_model.eval()
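    # ReD_Model is the classifier being fine-tuned; get_optim_policies() returns
    # parameter groups so SGD can apply per-group settings (for example, different
    # learning rates for pretrained and newly added layers).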

    model = ReD_Model(cfg)
    policies = model.get_optim_policies()
    model = torch.nn.DataParallel(model).cuda()
    criterion = nn.CrossEntropyLoss().cuda()
    best_mean = 0
    # optimizer= torch.optim.Adam(policies, cfg.LR,
    #                         weight_decay=cfg.WEIGHT_DECAY)
    optimizer = torch.optim.SGD(policies,
                                cfg.LR,
                                momentum=cfg.MOMENTUM,
                                weight_decay=cfg.WEIGHT_DECAY)

    for epoch in range(cfg.START_EPOCH, cfg.EPOCHS + 1):
        adjust_learning_rate(optimizer, epoch)
        # mean_acc=validate(val_loader,model,criterion,generate_model,epoch,writer)

        train(train_loader, model, criterion, generate_model, optimizer, epoch,
              writer)
        mean_acc = validate(val_loader, model, criterion, generate_model,
                            epoch, writer)
        if mean_acc > best_mean:
            best_mean = mean_acc
            print('best mean accuracy:', best_mean)
        else:
            print('best mean accuracy:', best_mean)
        writer.add_scalar('mean_acc_color', mean_acc, global_step=epoch)
        writer.add_scalar('best_meanacc_color', best_mean, global_step=epoch)
    writer.close()
Example #4
def main():
    global cfg, global_all_step
    global_all_step = 0
    cfg = DefaultConfig()
    args = {
        'resnet18': RESNET18_SUNRGBD_CONFIG().args(),
    }

    # Setting random seed
    # if cfg.MANUAL_SEED is None:
    #     cfg.MANUAL_SEED = random.randint(1, 10000)
    # random.seed(cfg.MANUAL_SEED)
    # torch.manual_seed(cfg.MANUAL_SEED)

    # args for different backbones
    cfg.parse(args['resnet18'])
    run_id = random.randint(1, 100000)
    summary_dir = '/home/lzy/summary/generateDepth/' + 'resnet152_train_rgb_place365_2e-3_' + str(
        run_id)
    if not os.path.exists(summary_dir):
        os.mkdir(summary_dir)
    writer = SummaryWriter(summary_dir)
    cfg.LR = 0.002
    os.environ["CUDA_VISIBLE_DEVICES"] = '2'
    device_ids = torch.cuda.device_count()
    print('device_ids:', device_ids)
    # project_name = reduce(lambda x, y: str(x) + '/' + str(y), os.path.realpath(__file__).split(os.sep)[:-1])
    # util.mkdir('logs')
    # normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
    #                                 std=[0.229, 0.224, 0.225])
    # normalize=transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    # data
    train_dataset = dataset.SingleDataset(
        cfg,
        data_dir=cfg.DATA_DIR_TRAIN,
        transform=transforms.Compose([
            dataset.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE)),
            dataset.RandomCrop((cfg.FINE_SIZE, cfg.FINE_SIZE)),
            dataset.RandomHorizontalFlip(),
            # dataset.RandomRotate(),
            # dataset.RandomFlip(),
            # dataset.PILBrightness(0.4),
            # dataset.PILContrast(0.4),
            # dataset.PILColorBalance(0.4),
            # dataset.PILSharpness(0.4),
            dataset.ToTensor(),
            dataset.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
        ]))

    val_dataset = dataset.SingleDataset(
        cfg,
        data_dir=cfg.DATA_DIR_VAL,
        transform=transforms.Compose([
            dataset.Resize((cfg.LOAD_SIZE, cfg.LOAD_SIZE)),
            dataset.CenterCrop((cfg.FINE_SIZE, cfg.FINE_SIZE)),
            dataset.ToTensor(),
            dataset.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
        ]))
    # train_loader = torch.utils.data.DataLoader(
    #     train_dataset, batch_size=cfg.BATCH_SIZE, shuffle=True,
    #     num_workers=4, pin_memory=True, sampler=None)
    # val_loader = torch.utils.data.DataLoader(
    #     val_dataset,batch_size=16, shuffle=False,
    #     num_workers=4, pin_memory=True)
    train_loader = DataProvider(cfg,
                                dataset=train_dataset,
                                batch_size=30,
                                shuffle=True)
    val_loader = DataProvider(cfg,
                              dataset=val_dataset,
                              batch_size=10,
                              shuffle=False)

    # RGB branch: ResNet-152 pre-trained on Places365 (365 scene classes)
    net_classification_1 = models.__dict__['resnet152'](num_classes=365)
    load_path = "./bestmodel/resnet152_places365.pth"
    checkpoint = torch.load(load_path,
                            map_location=lambda storage, loc: storage)
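    # The Places365 ResNet-152 checkpoint is stored as a plain mapping of tensors, so
    # its values are copied positionally into this model's state_dict keys below; this
    # only works if both dicts enumerate parameters in exactly the same order.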
    weight = []
    for k, v in checkpoint.items():
        weight.append(v)
    i = 0
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k in net_classification_1.state_dict():
        # print(k)
        new_state_dict[k] = weight[i]
        i += 1
    # state_dict = {str.replace(k, 'module.', ''): v for k, v in checkpoint['state_dict'].items()}
    # state_dict = {str.replace(k, 'norm.', 'norm'): v for k, v in state_dict.items()}
    # state_dict = {str.replace(k, 'conv.', 'conv'): v for k, v in state_dict.items()}
    # state_dict = {str.replace(k, 'normweight', 'norm.weight'): v for k, v in state_dict.items()}
    # state_dict = {str.replace(k, 'normbias', 'norm.bias'): v for k, v in state_dict.items()}
    # state_dict = {str.replace(k, 'normrunning_var', 'norm.running_var'): v for k, v in state_dict.items()}
    # state_dict = {str.replace(k, 'convweight', 'conv.weight'): v for k, v in state_dict.items()}
    # state_dict = {str.replace(k, 'normrunning_mean', 'norm.running_mean'): v for k, v in state_dict.items()}

    # normweight
    # state_dict = {str.replace(k, 'conv.', 'conv'): v for k, v in state_dict.items()}

    # best_mean_depth = checkpoint['best_mean_2']
    # print("load sunrgbd dataset:",best_mean_depth)
    net_classification_1.load_state_dict(new_state_dict)
    # net_classification_1=torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x16d_wsl')
    # net_classification_1.fc = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(2048, 1024),nn.LeakyReLU(inplace=True),nn.Linear(1024,67))
    # net_classification_1.fc = nn.Sequential(nn.Dropout(p=0.2), nn.Linear(2048, 1024), nn.LeakyReLU(inplace=True), nn.Dropout(p=0.2),nn.Linear(1024, 67))
    # net_classification_1.fc = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(1024, 67))
    # net_classification_1.classifier= nn.Linear(2208, 67)
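    # Replace the 365-way Places365 head with a freshly initialized 67-way classifier.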
    net_classification_1.fc = nn.Linear(2048, 67)

    # init_weights(net_classification_1.classifier, 'normal')
    init_weights(net_classification_1.fc, 'normal')

    # net_classification_1.load_state_dict(torch.load("./bestmodel/best_model_resnext_16d_2048_1024_dropout_0.5_b.pkl"))

    # for param in net_classification_1.parameters():
    # 	param.requires_grad = False
    # num_ftrs = net_classification_1.fc.in_features
    print(net_classification_1)
    net_classification_1.cuda()
    # net_classification_2.cuda()
    cudnn.benchmark = True

    net_classification_1 = torch.nn.DataParallel(net_classification_1).cuda()
    # net_classification_2 = torch.nn.DataParallel(net_classification_2).cuda()
    criterion = nn.CrossEntropyLoss().cuda()

    best_mean_1 = 0
    # best_mean_2=0
    # optimizer_1 = torch.optim.Adam(net_classification_1.parameters(), lr=cfg.LR, betas=(0.5, 0.999))
    optimizer_1 = torch.optim.SGD(net_classification_1.parameters(),
                                  lr=cfg.LR,
                                  momentum=cfg.MOMENTUM,
                                  weight_decay=cfg.WEIGHT_DECAY)
    schedulers = get_scheduler(optimizer_1, cfg, cfg.LR_POLICY)
    # optimizer_1 = torch.optim.SGD(net_classification_1.parameters(),lr=cfg.LR,momentum=cfg.MOMENTUM,weight_decay=cfg.WEIGHT_DECAY)
    # optimizer_2 = torch.optim.SGD(net_classification_2.parameters(),lr=cfg.LR,momentum=cfg.MOMENTUM,weight_decay=cfg.WEIGHT_DECAY)
    for epoch in range(0, 1000):
        if global_all_step > cfg.NITER_TOTAL:
            break
        # meanacc_1,meanacc_2=validate(val_loader, net_classification_1,net_classification_2,generate_model,criterion,epoch)
        printlr(optimizer_1, writer, epoch)
        train(train_loader, schedulers, net_classification_1, criterion,
              optimizer_1, epoch, writer)
        meanacc_1 = validate(val_loader, net_classification_1, criterion,
                             epoch, writer)

        # writer.add_image(depth_image[0])
        # save best
        if meanacc_1 > best_mean_1:
            best_mean_1 = meanacc_1
            print('best_mean_color:', str(best_mean_1))
            save_checkpoint(
                {
                    'epoch': epoch,
                    'arch': cfg.ARCH,
                    'state_dict': net_classification_1.state_dict(),
                    'best_mean_1': best_mean_1,
                    'optimizer': optimizer_1.state_dict(),
                },
                CorD=True)

        print('best_mean_color:', str(best_mean_1))
        writer.add_scalar('mean_acc_color', meanacc_1, global_step=epoch)
        # writer.add_scalar('mean_acc_depth', meanacc_2, global_step=epoch)
        writer.add_scalar('best_meanacc_color', best_mean_1, global_step=epoch)
        # writer.add_scalar('best_meanacc_depth', best_mean_2, global_step=epoch)

    writer.close()