Example 1
def main():
    global args, best_prec1, best_test_prec1, cond_best_test_prec1, best_cluster_acc, best_cluster_acc_2
    
    # define model
    model = Model_Construct(args)
    print(model)
    model = torch.nn.DataParallel(model).cuda()  # wrap the model for multi-GPU training
    
    # define learnable cluster centers (plain tensors; torch.autograd.Variable is deprecated)
    learn_cen = torch.zeros(args.num_classes, 2048, device='cuda', requires_grad=True)
    learn_cen_2 = torch.zeros(args.num_classes, args.num_neurons * 4, device='cuda', requires_grad=True)

    # define loss function/criterion and optimizer
    criterion = torch.nn.CrossEntropyLoss().cuda()
    criterion_cons = ConsensusLoss(nClass=args.num_classes, div=args.div).cuda()
    
    np.random.seed(1)  # fix random seeds for reproducible runs
    random.seed(1)
    torch.manual_seed(1)
    
    # apply different learning rates to different layers
    optimizer = torch.optim.SGD([
            {'params': model.module.conv1.parameters(), 'name': 'conv'},
            {'params': model.module.bn1.parameters(), 'name': 'conv'},
            {'params': model.module.layer1.parameters(), 'name': 'conv'},
            {'params': model.module.layer2.parameters(), 'name': 'conv'},
            {'params': model.module.layer3.parameters(), 'name': 'conv'},
            {'params': model.module.layer4.parameters(), 'name': 'conv'},
            {'params': model.module.fc1.parameters(), 'name': 'ca_cl'},
            {'params': model.module.fc2.parameters(), 'name': 'ca_cl'},
            {'params': learn_cen, 'name': 'conv'},
            {'params': learn_cen_2, 'name': 'conv'}
        ],
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay, 
                                    nesterov=args.nesterov)
    
    # resume
    epoch = 0                                
    init_state_dict = model.state_dict()
    if args.resume:
        if os.path.isfile(args.resume):
            print("==> loading checkpoints '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            best_test_prec1 = checkpoint['best_test_prec1']
            cond_best_test_prec1 = checkpoint['cond_best_test_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            learn_cen = checkpoint['learn_cen']
            learn_cen_2 = checkpoint['learn_cen_2']
            print("==> loaded checkpoint '{}'(epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            raise ValueError('The file to be resumed from does not exist!', args.resume)
    
    # make log directory
    if not os.path.isdir(args.log):
        os.makedirs(args.log)
    log = open(os.path.join(args.log, 'log.txt'), 'a')
    state = {k: v for k, v in args._get_kwargs()}
    log.write(json.dumps(state) + '\n')
    log.close()

    # start time
    log = open(os.path.join(args.log, 'log.txt'), 'a')
    log.write('\n-------------------------------------------\n')
    log.write(time.asctime(time.localtime(time.time())))
    log.write('\n-------------------------------------------')
    log.close()

    cudnn.benchmark = True
    
    # process data and prepare dataloaders
    train_loader_source, train_loader_target, val_loader_target, val_loader_target_t, val_loader_source = generate_dataloader(args)
    train_loader_target.dataset.tgts = list(np.array(torch.LongTensor(train_loader_target.dataset.tgts).fill_(-1))) # avoid using ground truth labels of target

    print('begin training')
    batch_number = count_epoch_on_large_dataset(train_loader_target, train_loader_source, args)
    num_itern_total = args.epochs * batch_number

    new_epoch_flag = False # if new epoch, new_epoch_flag=True
    test_flag = False # if test, test_flag=True
    
    src_cs = torch.cuda.FloatTensor(len(train_loader_source.dataset.tgts)).fill_(1) # initialize source weights
    
    count_itern_each_epoch = 0
    for itern in range(epoch * batch_number, num_itern_total):
        # evaluate on the target training and test data
        if (itern == 0) or (count_itern_each_epoch == batch_number):
            prec1, c_s, c_s_2, c_t, c_t_2, c_srctar, c_srctar_2, source_features, source_features_2, source_targets, target_features, target_features_2, target_targets, pseudo_labels = validate_compute_cen(val_loader_target, val_loader_source, model, criterion, epoch, args)
            test_acc = validate(val_loader_target_t, model, criterion, epoch, args)
            test_flag = True
            
            # K-means clustering or its variants
            if ((itern == 0) and args.src_cen_first) or (args.initial_cluster == 2):
                cen = c_s
                cen_2 = c_s_2
            else:
                cen = c_t
                cen_2 = c_t_2
            if (itern != 0) and (args.initial_cluster != 0) and (args.cluster_method == 'kernel_kmeans'):
                cluster_acc, c_t = kernel_k_means(target_features, target_targets, pseudo_labels, train_loader_target, epoch, model, args, best_cluster_acc)
                cluster_acc_2, c_t_2 = kernel_k_means(target_features_2, target_targets, pseudo_labels, train_loader_target, epoch, model, args, best_cluster_acc_2, change_target=False)
            elif args.cluster_method != 'spherical_kmeans':
                cluster_acc, c_t = k_means(target_features, target_targets, train_loader_target, epoch, model, cen, args, best_cluster_acc)
                cluster_acc_2, c_t_2 = k_means(target_features_2, target_targets, train_loader_target, epoch, model, cen_2, args, best_cluster_acc_2, change_target=False)
            else:  # args.cluster_method == 'spherical_kmeans'
                cluster_acc, c_t = spherical_k_means(target_features, target_targets, train_loader_target, epoch, model, cen, args, best_cluster_acc)
                cluster_acc_2, c_t_2 = spherical_k_means(target_features_2, target_targets, train_loader_target, epoch, model, cen_2, args, best_cluster_acc_2, change_target=False)
            
            # record the best accuracy of K-means clustering
            log = open(os.path.join(args.log, 'log.txt'), 'a')
            if cluster_acc > best_cluster_acc:
                best_cluster_acc = cluster_acc
                log.write('\n                                                          best_cluster acc: %.3f' % best_cluster_acc)
            if cluster_acc_2 > best_cluster_acc_2:
                best_cluster_acc_2 = cluster_acc_2
                log.write('\n                                                          best_cluster_2 acc: %.3f' % best_cluster_acc_2)
            log.close()
            
            # re-initialize learnable cluster centers
            if args.init_cen_on_st:
                cen = (c_t + c_s) / 2  # or c_srctar
                cen_2 = (c_t_2 + c_s_2) / 2  # or c_srctar_2
            else:
                cen = c_t
                cen_2 = c_t_2
            learn_cen.data = cen.data.clone()
            learn_cen_2.data = cen_2.data.clone()
            
            # select source samples
            if (itern != 0) and (args.src_soft_select or args.src_hard_select):
                src_cs = source_select(source_features, source_targets, target_features, pseudo_labels, train_loader_source, epoch, c_t.data.clone(), args)
            
            # use source pre-trained model to extract features for first clustering
            if (itern == 0) and args.src_pretr_first: 
                model.load_state_dict(init_state_dict)
                
            if itern != 0:
                count_itern_each_epoch = 0
                epoch += 1
            batch_number = count_epoch_on_large_dataset(train_loader_target, train_loader_source, args)
            train_loader_target_batch = enumerate(train_loader_target)
            train_loader_source_batch = enumerate(train_loader_source)
            
            new_epoch_flag = True
            
            del source_features
            del source_features_2
            del source_targets
            del target_features
            del target_features_2
            del target_targets
            del pseudo_labels
            gc.collect()
            torch.cuda.empty_cache()
        elif (args.src.find('visda') != -1) and (itern % int(num_itern_total / 200) == 0):
            prec1, _, _, _, _, _, _, _, _, _, _, _, _, _ = validate_compute_cen(val_loader_target, val_loader_source, model, criterion, epoch, args, compute_cen=False)
            test_acc = validate(val_loader_target_t, model, criterion, epoch, args)
            test_flag = True
        if test_flag:
            # record the best prec1 and save checkpoint
            log = open(os.path.join(args.log, 'log.txt'), 'a')
            if prec1 > best_prec1:
                best_prec1 = prec1
                cond_best_test_prec1 = 0
                log.write('\n                                                                                 best val acc till now: %.3f' % best_prec1)
            if test_acc > best_test_prec1:
                best_test_prec1 = test_acc
                log.write('\n                                                                                 best test acc till now: %.3f' % best_test_prec1)
            is_cond_best = ((prec1 == best_prec1) and (test_acc > cond_best_test_prec1))
            if is_cond_best:
                cond_best_test_prec1 = test_acc
                log.write('\n                                                                                 cond best test acc till now: %.3f' % cond_best_test_prec1)
            log.close()
            save_checkpoint({
                'epoch': epoch,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'learn_cen': learn_cen,
                'learn_cen_2': learn_cen_2,
                'best_prec1': best_prec1,
                'best_test_prec1': best_test_prec1,
                'cond_best_test_prec1': cond_best_test_prec1,
            }, is_cond_best, args)
            
            test_flag = False
        
        # early stop
        if epoch > args.stop_epoch:
            break

        # train for one iteration
        train_loader_source_batch, train_loader_target_batch = train(train_loader_source, train_loader_source_batch, train_loader_target, train_loader_target_batch, model, learn_cen, learn_cen_2, criterion_cons, optimizer, itern, epoch, new_epoch_flag, src_cs, args)

        model = model.cuda()
        new_epoch_flag = False
        count_itern_each_epoch += 1
    
    log = open(os.path.join(args.log, 'log.txt'), 'a')
    log.write('\n***   best val acc: %.3f   ***' % best_prec1)
    log.write('\n***   best test acc: %.3f   ***' % best_test_prec1)
    log.write('\n***   cond best test acc: %.3f   ***' % cond_best_test_prec1)
    # end time
    log.write('\n-------------------------------------------\n')
    log.write(time.asctime(time.localtime(time.time())))
    log.write('\n-------------------------------------------\n')
    log.close()
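The k_means helper called above is not included in this excerpt. As a point of reference, here is a minimal sketch of the center-update loop such a helper presumably performs, assuming features is an (N, D) tensor and init_centers is (K, D); the real helper also reassigns target pseudo-labels and reports clustering accuracy, which this sketch omits:

import torch

def simple_k_means(features, init_centers, n_iter=20):
    # Illustrative k-means only; not the repository's k_means().
    centers = init_centers.clone()
    for _ in range(n_iter):
        # assign each sample to its nearest center by Euclidean distance
        assign = torch.cdist(features, centers).argmin(dim=1)  # (N,)
        # recompute each center as the mean of its assigned samples
        for k in range(centers.size(0)):
            mask = assign == k
            if mask.any():
                centers[k] = features[mask].mean(dim=0)
    return centers, assign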
Example 2
def Process2_PartNet(args):
    log_now = args.dataset + '/PartNet'
    process_name = 'partnet'
    if os.path.isfile(log_now + '/final.txt'):
        print('the Process2_PartNet is finished')
        return
    best_prec1 = 0
    model = Model_Construct(args, process_name)
    model = torch.nn.DataParallel(model).cuda()
    criterion = nn.BCELoss().cuda()
    # print(model)
    # print('the learning rate for the new added layer is set to 1e-3 to slow down the speed of learning.')
    optimizer = torch.optim.SGD(
        [{
            'params': model.module.conv_model.parameters(),
            'name': 'pre-trained'
        }, {
            'params': model.module.classification_stream.parameters(),
            'name': 'new-added'
        }, {
            'params': model.module.detection_stream.parameters(),
            'name': 'new-added'
        }],
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay)
    start_epoch = args.start_epoch
    if args.resume:
        if os.path.isfile(args.resume):
            print("==> loading checkpoints '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("==> loaded checkpoint '{}'(epoch {})".format(
                args.resume, checkpoint['epoch']))
            args.resume = ''
        else:
            raise ValueError('The file to be resumed from does not exist!',
                             args.resume)
    else:
        if not os.path.isdir(log_now):
            os.makedirs(log_now)
        log = open(os.path.join(log_now, 'log.txt'), 'w')
        state = {k: v for k, v in args._get_kwargs()}
        log.write(json.dumps(state) + '\n')
        log.close()
    cudnn.benchmark = True
    train_loader, val_loader = generate_dataloader(args, process_name, -1)
    if args.test_only:
        validate(val_loader, model, criterion, 2000, log_now, process_name, args)
    for epoch in range(start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, log_now,
              process_name, args)
        # evaluate on the val data
        prec1 = validate(val_loader, model, criterion, epoch, log_now,
                         process_name, args)
        # record the best prec1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        if is_best:
            log = open(os.path.join(log_now, 'log.txt'), 'a')
            log.write("best acc %3f" % (best_prec1))
            log.close()
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer': optimizer.state_dict(),
            }, is_best, log_now)
        svb_timer = time.time()
        if args.svb and epoch != (args.epochs - 1):
            svb(model, args)
            print('the svb constraint is only applied to the classification stream.')
            svb_det(model, args)
            print('the svb time is: ', time.time() - svb_timer)
    #download_scores(val_loader, model, log_now, process_name, args)
    log = open(os.path.join(log_now, 'final.txt'), 'w')
    log.write("best acc %3f" % (best_prec1))
    log.close()
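The svb/svb_det helpers are not defined in this excerpt. Singular Value Bounding conventionally re-projects a weight matrix after each epoch so its singular values stay inside a band around 1; the following is a minimal sketch of that idea under an assumed bound eps, not the repository's actual implementation. It could be applied, for example, to the 2-D weights of the classification stream:

import torch

def bound_singular_values(weight, eps=0.5):
    # clamp the singular values of a 2-D weight matrix into [1/(1+eps), 1+eps]
    with torch.no_grad():
        u, s, vh = torch.linalg.svd(weight, full_matrices=False)
        s.clamp_(1.0 / (1.0 + eps), 1.0 + eps)
        weight.copy_(u @ torch.diag(s) @ vh)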
Example 3
def Process5_Final_Result(args):
    ############################# Image Level Classifier #############################
    log_now = args.dataset + '/Image_Classifier'
    process_name = 'image_classifier'
    model = Model_Construct(args, process_name)
    model = torch.nn.DataParallel(model).cuda()
    pre_trained_model = log_now + '/model_best.pth.tar'
    checkpoint = torch.load(pre_trained_model)
    model.load_state_dict(checkpoint['state_dict'])
    train_loader, val_loader = generate_dataloader(args, process_name, -1)
    download_scores(val_loader, model, log_now, process_name, args)
    ############################# PartNet ############################################
    log_now = args.dataset + '/PartNet'
    process_name = 'partnet'
    model = Model_Construct(args, process_name)
    model = torch.nn.DataParallel(model).cuda()
    pre_trained_model = log_now + '/model_best.pth.tar'
    checkpoint = torch.load(pre_trained_model)
    model.load_state_dict(checkpoint['state_dict'])
    train_loader, val_loader = generate_dataloader(args, process_name)
    download_scores(val_loader, model, log_now, process_name, args)
    ############################# Three Part Level Classifiers #######################
    # if the process is interrupted in this section, more modification is needed
    for i in range(args.num_part):
        log_now = args.dataset + '/Part_Classifiers_' + str(i)
        process_name = 'part_classifiers'
        model = Model_Construct(args, process_name)
        model = torch.nn.DataParallel(model).cuda()
        pre_trained_model = log_now + '/model_best.pth.tar'
        checkpoint = torch.load(pre_trained_model)
        model.load_state_dict(checkpoint['state_dict'])
        train_loader, val_loader = generate_dataloader(args, process_name, i)
        download_scores(val_loader, model, log_now, process_name, args)

    log_image = args.dataset + '/Image_Classifier'
    process_image = 'image_classifier'

    log_partnet = args.dataset + '/PartNet'
    process_partnet = 'partnet'

    log_part0 = args.dataset + '/Part_Classifiers_' + str(0)
    process_part0 = 'part_classifiers'

    log_part1 = args.dataset + '/Part_Classifiers_' + str(1)
    process_part1 = 'part_classifiers'

    log_part2 = args.dataset + '/Part_Classifiers_' + str(2)
    process_part2 = 'part_classifiers'

    image_table = torch.load(log_image + '/' + process_image + '.pth.tar')
    image_probability = image_table['scores']
    labels = image_table['labels']
    partnet_table = torch.load(log_partnet + '/' + process_partnet +
                               '.pth.tar')
    partnet_probability = partnet_table['scores']
    #######################
    part0_table = torch.load(log_part0 + '/' + process_part0 + '.pth.tar')
    part0_probability = part0_table['scores']
    ##########################
    part1_table = torch.load(log_part1 + '/' + process_part1 + '.pth.tar')
    part1_probability = part1_table['scores']
    ##########################
    part2_table = torch.load(log_part2 + '/' + process_part2 + '.pth.tar')
    part2_probability = part2_table['scores']
    ##########################

    probabilities_group = []
    probabilities_group.append(image_probability)
    probabilities_group.append(part0_probability)
    probabilities_group.append(part1_probability)
    probabilities_group.append(part2_probability)
    probabilities_group.append(partnet_probability)
    count = 0
    for i in range(len(labels)):
        # sum the scores from all five classifiers for sample i
        probability = probabilities_group[0][i]
        for j in range(1, len(probabilities_group)):
            probability = probability + probabilities_group[j][i]
        label = labels[i]
        value, index = probability.sort(0, descending=True)
        if index[0] == label:
            count = count + 1
    top1 = count / len(labels)
    print('the final top-1 accuracy obtained by averaging the part0/1/2, image, and partnet scores is',
          top1)
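The per-sample loop above can also be expressed compactly. A vectorized equivalent, assuming each score table is an (N, num_classes) tensor and labels is an (N,) LongTensor:

import torch

def ensemble_top1(probabilities_group, labels):
    # sum the per-model score tensors and take the arg-max class per sample
    fused = torch.stack(list(probabilities_group)).sum(dim=0)  # (N, num_classes)
    preds = fused.argmax(dim=1)
    return (preds == labels).float().mean().item()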
Example 4
def Process4_Part_Classifiers(args):
    # if the process is interrupted in this section, more modification is needed
    for i in range(args.num_part):
        log_now = args.dataset + '/Part_Classifiers_' + str(i)
        process_name = 'part_classifiers'
        if os.path.isfile(log_now + '/final.txt'):
            print('the Process4_Part_Classifier is finished', i)
            continue
        best_prec1 = 0
        model = Model_Construct(args, process_name)
        model = torch.nn.DataParallel(model).cuda()
        criterion = nn.CrossEntropyLoss().cuda()
        optimizer = torch.optim.SGD(
            [{
                'params': model.module.base_conv.parameters(),
                'name': 'pre-trained'
            }, {
                'params': model.module.fc.parameters(),
                'lr': args.lr,
                'name': 'new-added'
            }],
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay)
        log_image_model = args.dataset + '/Image_Classifier/model_best.pth.tar'
        checkpoint = torch.load(log_image_model)
        model.load_state_dict(checkpoint['state_dict'])
        print('load the cub fine-tuned model from:', log_image_model)
        start_epoch = args.start_epoch
        if args.resume:
            if os.path.isfile(args.resume):
                print("==> loading checkpoints '{}'".format(args.resume))
                checkpoint = torch.load(args.resume)
                start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("==> loaded checkpoint '{}'(epoch {})".format(
                    args.resume, checkpoint['epoch']))
                args.resume = ''
            else:
                raise ValueError('The file to be resumed from does not exist!',
                                 args.resume)
        else:
            if not os.path.isdir(log_now):
                os.makedirs(log_now)
            log = open(os.path.join(log_now, 'log.txt'), 'w')
            state = {k: v for k, v in args._get_kwargs()}
            log.write(json.dumps(state) + '\n')
            log.close()
        cudnn.benchmark = True
        train_loader, val_loader = generate_dataloader(args, process_name, i)
        if args.test_only:
            validate(val_loader, model, criterion, 2000, log_now, process_name, args)
        for epoch in range(start_epoch, args.epochs_part):
            # train for one epoch
            train(train_loader, model, criterion, optimizer, epoch, log_now,
                  process_name, args)
            # evaluate on the val data
            prec1 = validate(val_loader, model, criterion, epoch, log_now,
                             process_name, args)
            # record the best prec1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            if is_best:
                log = open(os.path.join(log_now, 'log.txt'), 'a')
                log.write("best acc %3f" % (best_prec1))
                log.close()
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'arch': args.arch,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                }, is_best, log_now)
        #download_scores(val_loader, model, log_now, process_name, args)
        log = open(os.path.join(log_now, 'final.txt'), 'w')
        log.write("best acc %3f" % (best_prec1))
        log.close()
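The save_checkpoint helper is used throughout these examples but never defined in them. A conventional implementation, consistent with the model_best.pth.tar files that the later processes load, might look like the sketch below (an assumption, not the repository's code; note that Example 1 passes args rather than a log directory, so its variant differs):

import os
import shutil
import torch

def save_checkpoint(state, is_best, log_dir, filename='checkpoint.pth.tar'):
    # persist the latest state and keep a separate copy of the best one
    path = os.path.join(log_dir, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(log_dir, 'model_best.pth.tar'))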
Example 5
def Process3_Download_Proposals(args):
    log_now = args.dataset + '/Download_Proposals'
    process_name = 'download_proposals'
    if os.path.isfile(log_now + '/final.txt'):
        print('the Process3_download proposals is finished')
        return

    model = Model_Construct(args, process_name)
    model = torch.nn.DataParallel(model).cuda()

    optimizer = torch.optim.SGD(
        [{
            'params': model.module.conv_model.parameters(),
            'name': 'pre-trained'
        }, {
            'params': model.module.classification_stream.parameters(),
            'name': 'new-added'
        }, {
            'params': model.module.detection_stream.parameters(),
            'name': 'new-added'
        }],
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay)
    log_partnet_model = args.dataset + '/PartNet/model_best.pth.tar'
    checkpoint = torch.load(log_partnet_model)
    model.load_state_dict(checkpoint['state_dict'])
    print('load the pre-trained partnet model from:', log_partnet_model)

    if args.resume:
        if os.path.isfile(args.resume):
            print("==> loading checkpoints '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("==> loaded checkpoint '{}'(epoch {})".format(
                args.resume, checkpoint['epoch']))
            args.resume = ''
        else:
            raise ValueError('The file to be resumed from does not exist!',
                             args.resume)
    else:
        if not os.path.isdir(log_now):
            os.makedirs(log_now)
        log = open(os.path.join(log_now, 'log.txt'), 'w')
        state = {k: v for k, v in args._get_kwargs()}
        log.write(json.dumps(state) + '\n')
        log.close()

    cudnn.benchmark = True
    train_loader, val_loader = generate_dataloader(args, process_name)

    for epoch in range(1):

        download_part_proposals(train_loader, model, epoch, log_now,
                                process_name, 'train', args)

        best_prec1 = download_part_proposals(val_loader, model, epoch, log_now,
                                             process_name, 'val', args)

    log = open(os.path.join(log_now, 'final.txt'), 'w')
    log.write("best acc %3f" % (best_prec1))
    log.close()
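Each of these processes loads a checkpoint into a DataParallel-wrapped model, which works only because the checkpoints were also saved from wrapped models (their keys carry a module. prefix). A defensive loader that tolerates both layouts, offered as a sketch rather than as part of the original helpers:

from collections import OrderedDict

def load_state_dict_flexible(model, state_dict):
    # align the presence/absence of the 'module.' prefix with the target model
    wants_prefix = next(iter(model.state_dict())).startswith('module.')
    has_prefix = next(iter(state_dict)).startswith('module.')
    if wants_prefix and not has_prefix:
        state_dict = OrderedDict(('module.' + k, v) for k, v in state_dict.items())
    elif has_prefix and not wants_prefix:
        state_dict = OrderedDict((k[len('module.'):], v) for k, v in state_dict.items())
    model.load_state_dict(state_dict)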
Example 6
def main():
    global args, best_prec1, current_epoch, epoch_count_dataset
    current_epoch = 0
    epoch_count_dataset = 'source'
    args = opts()
    model_source = Model_Construct(args)
    # define-multi GPU
    model_source = torch.nn.DataParallel(model_source).cuda()

    # define loss function (criterion) and optimizer

    criterion = nn.CrossEntropyLoss().cuda()
    criterion_bce = nn.BCEWithLogitsLoss().cuda()
    np.random.seed(1)  # fix random seeds for reproducible runs
    random.seed(1)
    # apply different learning rates to different layers
    def make_optimizer():
        # both domain-feature settings use identical parameter groups
        return torch.optim.SGD(
            [{
                'params': model_source.module.base_conv.parameters(),
                'name': 'conv'
            }, {
                'params': model_source.module.domain_classifier.parameters(),
                'name': 'do_cl'
            }, {
                'params': model_source.module.fc.parameters(),
                'name': 'ca_cl'
            }],
            lr=args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay)

    if args.domain_feature == 'original':
        print('domain feature is original')
    elif args.domain_feature in ('full_bilinear', 'random_bilinear'):
        print('the domain feature is bilinear')
    else:
        raise ValueError('the requested domain feature is not available',
                         args.domain_feature)
    optimizer_feature = make_optimizer()
    optimizer_domain = make_optimizer()

    if args.resume:
        if os.path.isfile(args.resume):
            print("==> loading checkpoints '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            current_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model_source.load_state_dict(checkpoint['source_state_dict'])
            print("==> loaded checkpoint '{}'(epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            raise ValueError('The file to be resumed from does not exist!',
                             args.resume)
    else:
        if not os.path.isdir(args.log):
            os.makedirs(args.log)
        log = open(os.path.join(args.log, 'log.txt'), 'w')
        state = {k: v for k, v in args._get_kwargs()}
        log.write(json.dumps(state) + '\n')
        log.close()

    log = open(os.path.join(args.log, 'log.txt'), 'a')
    local_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    log.write(local_time)
    log.close()

    cudnn.benchmark = True
    # process the data and prepare the dataloaders
    train_loader_source, train_loader_target, val_loader_target, val_loader_source = generate_dataloader(
        args)

    print('begin training')
    train_loader_source_batch = enumerate(train_loader_source)
    train_loader_target_batch = enumerate(train_loader_target)
    batch_number_s = len(train_loader_source)
    batch_number_t = len(train_loader_target)
    if batch_number_s < batch_number_t:
        epoch_count_dataset = 'target'

    for epoch in range(args.start_epoch, 1000000000000000000):  # effectively infinite; exits via the break below
        # train for one epoch
        train_loader_source_batch, train_loader_target_batch, current_epoch, new_epoch_flag = train(
            train_loader_source, train_loader_source_batch,
            train_loader_target, train_loader_target_batch, model_source,
            criterion, criterion_bce, optimizer_feature, optimizer_domain,
            epoch, args, current_epoch, epoch_count_dataset)
        # evaluate on the val data
        if new_epoch_flag:
            prec1 = validate(val_loader_target, model_source, criterion,
                             current_epoch, args)
            # record the best prec1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            if is_best:
                log = open(os.path.join(args.log, 'log.txt'), 'a')
                log.write('                        Target_T1 acc: %.3f' %
                          (best_prec1))
                log.close()
                save_checkpoint(
                    {
                        'epoch': current_epoch + 1,
                        'arch': args.arch,
                        'source_state_dict': model_source.state_dict(),
                        'best_prec1': best_prec1,
                    }, is_best, args)
            if (current_epoch + 1) % args.domain_freq == 0:
                download_domain_scores(val_loader_target, val_loader_source,
                                       model_source, criterion, current_epoch,
                                       args)
        if current_epoch > args.epochs:
            break
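The train() function that drives optimizer_feature and optimizer_domain is not shown. A self-contained toy of the alternating adversarial update the two optimizers suggest: the domain classifier learns to separate source from target, then the features are updated to fool it. All module names below are stand-ins, not Model_Construct:

import torch
import torch.nn as nn

# toy stand-ins for the feature extractor and domain classifier
features = nn.Linear(16, 8)
domain_classifier = nn.Linear(8, 1)
criterion_bce = nn.BCEWithLogitsLoss()

optimizer_domain = torch.optim.SGD(domain_classifier.parameters(), lr=0.01)
optimizer_feature = torch.optim.SGD(features.parameters(), lr=0.01)

x_s, x_t = torch.randn(4, 16), torch.randn(4, 16)  # source / target batches

# step 1: train the domain classifier to separate source (1) from target (0)
d_s = domain_classifier(features(x_s).detach())
d_t = domain_classifier(features(x_t).detach())
loss_domain = (criterion_bce(d_s, torch.ones_like(d_s)) +
               criterion_bce(d_t, torch.zeros_like(d_t)))
optimizer_domain.zero_grad()
loss_domain.backward()
optimizer_domain.step()

# step 2: train the features so target samples look like source (label 1)
d_t = domain_classifier(features(x_t))
loss_feature = criterion_bce(d_t, torch.ones_like(d_t))
optimizer_feature.zero_grad()
loss_feature.backward()
optimizer_feature.step()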