示例#1
0
import os
import torch
import torch.nn as nn
import warnings
warnings.filterwarnings('ignore')


if __name__ == '__main__':
    # Create the experiment directory; the value returned by mkExpDir is kept
    # as the logger and later handed to get_loss_dict.
    _logger = mkExpDir(args)

    # Select the dataloader: the adaptive variant when args.adaptTrain is set;
    # otherwise the regular one, except that no dataloader is built at all
    # when args.test is set.
    if args.adaptTrain:
        _dataloader = adaptDataloader.get_dataloader(args)
    elif args.test:
        _dataloader = None
    else:
        _dataloader = dataloader.get_dataloader(args)

    # Debug output describing whichever dataloader was selected.
    print(type(_dataloader))
    print('###################################################')
    print(_dataloader)

    # Build both models and move them to the CPU or to CUDA.
    device = torch.device('cuda' if not args.cpu else 'cpu')
    _model = TTSR.TTSR(args).to(device)
    _dualmodel = DualModel.DualModel(args).to(device)

    # Only the main model is wrapped for multi-GPU execution.
    if not args.cpu and args.num_gpu > 1:
        _model = nn.DataParallel(_model, device_ids=list(range(args.num_gpu)))

    # Loss functions.
    _loss_all = get_loss_dict(args, _logger)

    ### trainer
示例#2
0
def main(args: argparse.Namespace):
    """Run one phase of the two-head classifier pipeline.

    A pre-trained ResNet-50 feature extractor and two ``ImageClassifierHead``
    instances are created and moved to ``device``. For every phase other than
    ``'train'`` the weights stored in the ``'best'`` checkpoint are restored
    before anything else happens.

    Phases (selected by ``args.phase``):
        * ``'analysis'``: collect features from the source and target training
          loaders, save a t-SNE plot, print the A-distance, then return.
        * ``'test'``: print the result of ``validate`` on the test loader,
          then return.
        * anything else: train for ``args.epochs`` epochs. After each epoch
          the ``'latest'`` checkpoint is written and copied to ``'best'``
          whenever ``max(results)`` beats the best value seen so far. Finally
          the ``'best'`` weights are reloaded and the test accuracy printed.

    NOTE(review): ``device`` is never assigned inside this function; it is
    presumably a module-level global -- confirm against the full file.

    Args:
        args: parsed command-line options; also forwarded unchanged to
            ``get_dataloader``, ``train`` and ``validate``.
    """
    logger = CompleteLogger(args.log, args.phase)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        seed_notice = (
            'You have chosen to seed training. '
            'This will turn on the CUDNN deterministic setting, '
            'which can slow down your training considerably! '
            'You may see unexpected behavior when restarting '
            'from checkpoints.'
        )
        warnings.warn(seed_notice)
    cudnn.benchmark = True

    # ---- data -----------------------------------------------------------
    train_source_loader, train_source_iter = get_dataloader(args, phase='train', domain='source')
    train_target_loader, train_target_iter = get_dataloader(args, phase='train', domain='target')
    val_loader, _ = get_dataloader(args, phase='val', domain='target')
    # The validation loader doubles as the test loader.
    test_loader = val_loader

    # ---- model ----------------------------------------------------------
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = resnet50(pretrained=True).to(device)
    num_classes = train_source_loader.dataset.num_classes

    def make_head():
        # One classifier head on top of the backbone features.
        return ImageClassifierHead(backbone.out_features, num_classes, args.bottleneck_dim).to(device)

    head1 = make_head()
    head2 = make_head()

    # ---- optimizers -----------------------------------------------------
    # The backbone optimizer uses no momentum; the two heads share one
    # optimizer with momentum 0.9.
    optimizer_g = SGD(backbone.parameters(), lr=args.lr, weight_decay=0.0005)
    head_param_groups = [{'params': head.parameters()} for head in (head1, head2)]
    optimizer_f = SGD(head_param_groups, momentum=0.9, lr=args.lr, weight_decay=0.0005)

    def restore_best():
        # Read the 'best' checkpoint onto the CPU and copy its weights into
        # the backbone and both heads.
        state = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        backbone.load_state_dict(state['G'])
        head1.load_state_dict(state['F1'])
        head2.load_state_dict(state['F2'])

    if args.phase != 'train':
        restore_best()

    if args.phase == 'analysis':
        # Extract features from both domains.
        feature_extractor = backbone.to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # Plot t-SNE.
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # A-distance: a measure of distribution discrepancy.
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        print(validate(test_loader, backbone, head1, head2, args))
        return

    # ---- training loop --------------------------------------------------
    best_acc1 = 0.
    best_results = None
    for epoch in range(args.epochs):
        train(train_source_iter, train_target_iter, backbone, head1, head2,
              optimizer_g, optimizer_f, epoch, args)
        results = validate(val_loader, backbone, head1, head2, args)

        snapshot = {
            'G': backbone.state_dict(),
            'F1': head1.state_dict(),
            'F2': head2.state_dict(),
        }
        torch.save(snapshot, logger.get_checkpoint_path('latest'))

        epoch_best = max(results)
        if epoch_best > best_acc1:
            # Promote the snapshot just written to be the new 'best'.
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
            best_acc1 = epoch_best
            best_results = results

    print("best_acc1 = {:3.1f}, results = {}".format(best_acc1, best_results))

    # Evaluate the best checkpoint on the test loader.
    restore_best()
    results = validate(test_loader, backbone, head1, head2, args)
    print("test_acc1 = {:3.1f}".format(max(results)))
    logger.close()
示例#3
0
文件: main.py 项目: liuwenbo3/TTSR
from option import args
from utils import mkExpDir
from dataset import dataloader
from model import TTSR
from loss.loss import get_loss_dict
from trainer import Trainer

import os
import torch

if __name__ == '__main__':
    ### make save_dir
    _logger = mkExpDir(args)

    ### dataloader of training set and testing set
    _dataloader = dataloader.get_dataloader(args)

    ### device and model
    device = torch.device('cpu' if args.cpu else 'cuda')
    _model = TTSR.TTSR(args).to(device)
    if ((not args.cpu) and (args.num_gpu > 1)):
        _model = nn.DataParallel(_model, list(range(args.num_gpu)))

    ### loss
    _loss_all = get_loss_dict(args, _logger)

    ### trainer
    t = Trainer(args, _logger, _dataloader, _model, _loss_all)

    ### eval / train
    if (args.eval):
示例#4
0
def main(args):
    """Train, test, or analyse a domain-adversarial classifier.

    An ``ImageClassifier`` on a pre-trained ResNet-50 backbone and a
    ``DomainDiscriminator`` are built and optimised jointly with one SGD
    optimizer. For every phase other than ``'train'`` the ``'best'``
    checkpoint is loaded into the classifier first.

    Phases (selected by ``args.phase``):
        * ``'analysis'``: collect features from the source and target
          loaders, save a t-SNE plot, print the A-distance, then return.
        * ``'test'``: print the result of ``validate`` on the test loader,
          then return.
        * anything else: train for ``args.epochs`` epochs. After each epoch
          the classifier is saved as ``'latest'`` and copied to ``'best'``
          when the validation accuracy improves. Finally ``'best'`` is
          reloaded and the test accuracy printed.

    Note that ``args.arch`` is only printed; the backbone is always
    ``resnet50(pretrained=True)``.

    Args:
        args: parsed options. Fields read here: ``log``, ``phase``, ``seed``,
            ``arch``, ``lr``, ``momentum``, ``weight_decay``, ``lr_gamma``,
            ``lr_decay`` and ``epochs``. ``args`` is also forwarded to
            ``get_dataloader``, ``train`` and ``validate``.
    """
    logger = CompleteLogger(args.log, args.phase)
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn(
            'You have chosen to seed training. '
            'This will turn on the CUDNN deterministic setting, '
            'which can slow down your training considerably! '
            'You may see unexpected behavior when restarting from checkpoints.'
        )
    cudnn.benchmark = True
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ---- load data ------------------------------------------------------
    source_dataloader, train_source_iter = get_dataloader(args,
                                                          phase='train',
                                                          domain='source')
    target_dataloader, train_target_iter = get_dataloader(args,
                                                          phase='train',
                                                          domain='target')
    # Validation uses the target domain; the key must be spelled exactly like
    # the one used for the target training loader above ('target').
    val_dataloader, val_target_iter = get_dataloader(args,
                                                     phase='val',
                                                     domain='target')
    # The validation loader doubles as the test loader.
    test_dataloader, test_target_iter = val_dataloader, val_target_iter

    # ---- build models ---------------------------------------------------
    print("=> using pre-trained model '{}'".format(args.arch))
    backbone = resnet50(pretrained=True)
    classifier = ImageClassifier(
        backbone, source_dataloader.dataset.num_classes).to(device)
    domain_discriminator = DomainDiscriminator(
        in_feature=classifier.features_dim, hidden_size=1024).to(device)

    # ---- optimizer, learning-rate schedule, adversarial loss --------------
    optimizer = torch.optim.SGD(classifier.get_parameters() +
                                domain_discriminator.get_parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)
    # Multiplicative factor: lr * (1 + lr_gamma * step) ** (-lr_decay).
    lr_schedule = LambdaLR(
        optimizer, lambda x: args.lr *
        (1. + args.lr_gamma * float(x))**(-args.lr_decay))
    domain_adv = DomainAdversarialLoss(domain_discriminator)

    # ---- non-training phases start from the best checkpoint ---------------
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'),
                                map_location='cpu')
        classifier.load_state_dict(checkpoint)

    if args.phase == 'analysis':
        # Extract features from both domains.
        feature_extractor = nn.Sequential(classifier.backbone,
                                          classifier.bottleneck).to(device)
        source_feature = collect_feature(source_dataloader, feature_extractor,
                                         device)
        target_feature = collect_feature(target_dataloader, feature_extractor,
                                         device)
        # Plot t-SNE.
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.png')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # A-distance: a measure of distribution discrepancy.
        A_distance = a_distance.calculate(source_feature, target_feature,
                                          device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = validate(test_dataloader, test_target_iter, classifier, device,
                        args)
        print(acc1)
        return

    # ---- training loop --------------------------------------------------
    best_acc1 = 0.
    for epoch in range(args.epochs):
        # Train for one epoch.
        train(train_source_iter, train_target_iter, classifier, domain_adv,
              optimizer, lr_schedule, epoch, device, args)
        # Evaluate on the validation loader.
        acc1 = validate(val_dataloader, val_target_iter, classifier, device,
                        args)
        # Save the current weights; promote them to 'best' on improvement.
        torch.save(classifier.state_dict(),
                   logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'),
                        logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)
    print('best_acc1 = {:3.1f}'.format(best_acc1))

    # Evaluate the best checkpoint on the test loader.
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = validate(test_dataloader, test_target_iter, classifier, device,
                    args)
    print('test_acc1 = {:3.1f}'.format(acc1))
    logger.close()
示例#5
0
 def get_dataloader(self):
     """Return the result of the ``get_dataloader`` function looked up outside this class.

     Inside a method body the bare name ``get_dataloader`` does not refer to
     this method, so the call below delegates to a function of the same name
     from the enclosing module scope. It is given this object, the base
     dataset's ``collate_fn`` and ``self.args``.
     """
     collate_fn = self.base_dataset.collate_fn
     return get_dataloader(self, collate_fn, self.args)