Пример #1
0
def main():
    """Evaluate a trained student model on the VisDA validation list.

    Parses the command line, copies every key of the ``common`` section of the
    YAML file named by ``args.config`` onto ``args``, builds the model, loads
    the weights stored at ``checkpoint/<exp_name>/ckptmodel_best.pth.tar`` and
    runs two evaluations over ``./data/visda/list/validation_list.txt``:
    ``validate`` on a single scaled crop and ``validate_multi`` on the
    multi-view augmented dataset.

    Side effects: sets the module globals ``args`` and ``best_prec1``, creates
    ``checkpoint/<exp_name>/`` if it does not exist and logs every argument
    to its ``log.txt``.
    """
    global args, best_prec1
    args = parser.parse_args()

    # yaml.load() without an explicit Loader can build arbitrary Python
    # objects and raises TypeError on PyYAML >= 6; the config is plain data.
    with open(args.config) as f:
        config = yaml.safe_load(f)

    for k, v in config['common'].items():
        setattr(args, k, v)
    torch.cuda.manual_seed(int(time.time()) % 1000)

    # Per-channel mean / std shared by every normalising transform below.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # create model
    print("=> creating model '{}'".format(args.arch))
    if args.arch.startswith('inception_v3'):
        print('inception_v3 without aux_logits!')
        input_size = 299
        # Must be bound to `student_model` (it was `model`); otherwise every
        # line below raised NameError for inception_v3 architectures.
        # NOTE(review): the message says "without aux_logits" yet
        # aux_logits=True and a fixed 1000 classes are passed -- confirm.
        student_model = models.__dict__[args.arch](aux_logits=True,
                                                   num_classes=1000,
                                                   pretrained=args.pretrained)
    else:
        input_size = 160
        # Floor division keeps the pooling size an int (5); `/` gave 5.0.
        student_model = models.__dict__[args.arch](
            num_classes=args.num_classes,
            pretrained=args.pretrained,
            avgpool_size=input_size // 32)
    student_model.cuda()

    args.save_path = "checkpoint/" + args.exp_name

    # makedirs also creates the parent "checkpoint/" directory when it is
    # missing, where os.mkdir would fail.
    if not osp.exists(args.save_path):
        os.makedirs(args.save_path)

    logger = create_logger('global_logger', args.save_path + '/log.txt')

    for key, val in vars(args).items():
        logger.info("{:16} {}".format(key, val))

    criterion = nn.CrossEntropyLoss()
    print("Build network")
    best_prec1 = 0
    load_state(args.save_path + "/ckptmodel_best.pth.tar", student_model)

    cudnn.benchmark = True

    # Data loading code
    val_list = "./data/visda/list/validation_list.txt"
    # Mean of the channel means scaled to 0..255 and rounded; presumably the
    # fill colour for pixels exposed by the affine augmentation.
    border_value = int(np.mean(mean) * 255 + 0.5)
    test_aug = se_transforms.ImageAugmentation(True,
                                               0,
                                               rot_std=0.0,
                                               scale_u_range=[0.75, 1.333],
                                               affine_std=0,
                                               scale_x_range=None,
                                               scale_y_range=None)

    # Single deterministic scale-and-crop per image.
    val_dataset = NormalDataset(args.val_root,
                                val_list,
                                transform=transforms.Compose([
                                    se_transforms.ScaleAndCrop(
                                        (input_size, input_size), args.padding,
                                        False, np.array(mean), np.array(std))
                                ]),
                                is_train=False,
                                args=args)

    val_loader = DataLoader(val_dataset,
                            batch_size=1,
                            shuffle=False,
                            num_workers=args.workers)

    # Presumably 16 augmented views per image (first argument), using test_aug.
    val_multi_dataset = NormalDataset(
        args.val_root,
        val_list,
        transform=transforms.Compose([
            se_transforms.ScaleCropAndAugmentAffineMultiple(
                16, (input_size, input_size), args.padding, True, test_aug,
                border_value, np.array(mean), np.array(std))
        ]),
        is_train=False,
        args=args)

    val_multi_loader = DataLoader(val_multi_dataset,
                                  batch_size=1,
                                  shuffle=False,
                                  num_workers=args.workers)

    validate(val_loader, student_model, criterion)
    validate_multi(val_multi_loader, student_model, criterion)
Пример #2
0
def main():
    """Train a student model on source and target datasets, or only evaluate.

    Parses the command line, copies every key of the ``common`` section of the
    YAML file named by ``args.config`` onto ``args``, builds the model and an
    Adam optimizer, optionally restores a checkpoint from ``args.load_path``,
    builds the source / target / validation loaders and then either runs
    ``validate`` once (when ``args.evaluate`` is set) or calls ``train``.

    Side effects: sets the module globals ``args`` and ``best_prec1``, creates
    ``checkpoint/<exp_name>/`` if it does not exist, and opens a
    ``SummaryWriter`` and a ``log.txt`` logger there.
    """
    global args, best_prec1
    args = parser.parse_args()

    # yaml.load() without an explicit Loader can build arbitrary Python
    # objects and raises TypeError on PyYAML >= 6; the config is plain data.
    with open(args.config) as f:
        config = yaml.safe_load(f)

    for k, v in config['common'].items():
        setattr(args, k, v)
    torch.cuda.manual_seed(int(time.time()) % 1000)

    # Per-channel mean / std shared by every normalising transform below.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    # create model
    print("=> creating model '{}'".format(args.arch))
    if args.arch.startswith('inception_v3'):
        print('inception_v3 without aux_logits!')
        input_size = 299
        # Must be bound to `student_model` (it was `model`); otherwise every
        # line below raised NameError for inception_v3 architectures.
        # NOTE(review): the message says "without aux_logits" yet
        # aux_logits=True and a fixed 1000 classes are passed -- confirm.
        student_model = models.__dict__[args.arch](aux_logits=True,
                                                   num_classes=1000,
                                                   pretrained=args.pretrained)
    else:
        input_size = 160
        # Floor division keeps the pooling size an int (5); `/` gave 5.0.
        student_model = models.__dict__[args.arch](
            num_classes=args.num_classes,
            pretrained=args.pretrained,
            avgpool_size=input_size // 32)
    student_model.cuda()

    args.save_path = "checkpoint/" + args.exp_name

    # makedirs also creates the parent "checkpoint/" directory when it is
    # missing, where os.mkdir would fail.
    if not osp.exists(args.save_path):
        os.makedirs(args.save_path)

    tb_logger = SummaryWriter(args.save_path)
    logger = create_logger('global_logger', args.save_path + '/log.txt')

    for key, val in vars(args).items():
        logger.info("{:16} {}".format(key, val))

    logger.info("filename {}".format(osp.basename(__file__)))

    # define loss function (criterion)
    weight = None
    if args.use_weight:
        # Per-class weights from the space-separated source list; column 1 is
        # presumably the integer class label -- confirm against the list files.
        df_train = pd.read_csv(args.train_source_source, sep=" ", header=None)
        counts = df_train[1].value_counts().sort_index()
        # Inverse class frequency, ordered by label.
        weight = torch.from_numpy(
            (len(df_train) / counts).values).float().cuda()
        w_min, w_max = torch.min(weight), torch.max(weight)
        if w_max > w_min:
            # Min-max scale to [0, 1], then shift so the weights average to 1.
            weight = (weight - w_min) / (w_max - w_min)
            weight = (1 - torch.mean(weight)) + weight
        else:
            # Every class is equally frequent: min-max scaling would divide
            # by zero (NaN weights). Uniform weights also average to 1.
            weight = torch.ones_like(weight)
    criterion = nn.CrossEntropyLoss(weight=weight)

    # optimizer; `== True` (not truthiness) is kept so non-bool values of
    # args.pretrained take the same branch as before.
    if args.pretrained == True:  # noqa: E712
        # Trainable backbone parameters use base_lr * 0.1; the classifier
        # head uses the full base_lr.
        classifier_ids = {id(p) for p in student_model.classifier.parameters()}
        base_params = [
            p for p in student_model.parameters()
            if id(p) not in classifier_ids and p.requires_grad
        ]
        student_optimizer = torch.optim.Adam(
            [{
                'params': base_params
            }, {
                'params': student_model.classifier.parameters(),
                'lr': args.base_lr
            }], args.base_lr * 0.1)
    else:
        student_optimizer = torch.optim.Adam(student_model.parameters(),
                                             args.base_lr)

    # optionally resume from a checkpoint
    print("Build network")
    last_iter = -1
    best_prec1 = 0
    if args.load_path:
        print(args.load_path)
        # Restore into student_model (was the undefined name `model`, which
        # raised NameError whenever load_path was set).
        if args.resume_opt:
            best_prec1, last_iter = load_state(args.load_path,
                                               student_model,
                                               optimizer=student_optimizer)
        else:
            load_state(args.load_path, student_model)

    cudnn.benchmark = True

    # Data loading code
    src_aug = se_transforms.ImageAugmentation(
        args.src_hflip,
        args.src_xlat_range,
        args.src_affine_std,
        rot_std=args.src_rot_std,
        intens_scale_range_lower=args.src_intens_scale_range_lower,
        intens_scale_range_upper=args.src_intens_scale_range_upper,
        colour_rot_std=args.src_colour_rot_std,
        colour_off_std=args.src_colour_off_std,
        greyscale=args.src_greyscale,
        scale_u_range=args.src_scale_u_range,
        scale_x_range=(None, None),
        scale_y_range=(None, None),
        cutout_probability=args.src_cutout_prob,
        cutout_size=args.src_cutout_size)
    tgt_aug = se_transforms.ImageAugmentation(
        args.tgt_hflip,
        args.tgt_xlat_range,
        args.tgt_affine_std,
        rot_std=args.tgt_rot_std,
        intens_scale_range_lower=args.tgt_intens_scale_range_lower,
        intens_scale_range_upper=args.tgt_intens_scale_range_upper,
        colour_rot_std=args.tgt_colour_rot_std,
        colour_off_std=args.tgt_colour_off_std,
        greyscale=args.tgt_greyscale,
        scale_u_range=args.tgt_scale_u_range,
        scale_x_range=[None, None],
        scale_y_range=[None, None],
        cutout_probability=args.tgt_cutout_prob,
        cutout_size=args.tgt_cutout_size)

    # Mean of the channel means scaled to 0..255 and rounded; presumably the
    # fill colour for pixels exposed by the affine augmentation.
    border_value = int(np.mean(mean) * 255 + 0.5)

    train_source_dataset = NormalDataset(
        args.train_source_root,
        args.train_source_source,
        transform=transforms.Compose([
            se_transforms.ScaleCropAndAugmentAffine(
                (input_size, input_size), args.padding, True, src_aug,
                border_value, np.array(mean), np.array(std))
        ]),
        args=args)

    train_target_dataset = TeacherDataset(
        args.train_target_root,
        args.train_target_source,
        transform=transforms.Compose([
            se_transforms.ScaleCropAndAugmentAffinePair(
                (input_size, input_size), args.padding, 0, True, tgt_aug,
                border_value, np.array(mean), np.array(std))
        ]),
        args=args)

    val_dataset = NormalDataset(args.val_root,
                                args.val_source,
                                transform=transforms.Compose([
                                    se_transforms.ScaleAndCrop(
                                        (input_size, input_size), args.padding,
                                        False, np.array(mean), np.array(std))
                                ]),
                                is_train=False,
                                args=args)

    train_source_loader = DataLoader(train_source_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers)

    train_target_loader = DataLoader(train_target_dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.workers)

    val_loader = DataLoader(val_dataset,
                            batch_size=args.batch_size,
                            shuffle=False,
                            num_workers=args.workers)

    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        student_optimizer, args.lr_steps, args.lr_gamma)

    if args.evaluate:
        validate(val_loader, student_model, criterion)
        return

    train(train_source_loader,
          train_target_loader,
          val_loader,
          student_model,
          criterion,
          student_optimizer=student_optimizer,
          lr_scheduler=lr_scheduler,
          start_iter=last_iter + 1,
          tb_logger=tb_logger)