Example #1
File: train.py  Project: ypwhs/doodle
        epoch = 5
    elif tag.isnumeric():
        epoch = int(tag)

if args.half:
    model = network_to_half(model)
model = model.cuda()
model.name = f'{args.model}_{args.tag}'

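# Broadcast the initial parameters from rank 0 so every Horovod worker starts from identical weights.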
hvd.broadcast_parameters(model.state_dict(), root_rank=0)

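# Scale the LR linearly with the global batch size (reference 128) and ramp the
# multiplier from 1 to scale_lr over 5 * len(valid_loader) warmup steps.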
scale_lr = batch_size * hvd.size() / 128
scheduler_warmup = LambdaLR(
    optimizer, lambda step: 1 + (scale_lr - 1) * step / len(valid_loader) / 5)
scheduler_train = MultiStepLR(optimizer, milestones=[80, 160, 200])
scheduler_train.base_lrs = [x * scale_lr for x in scheduler_train.base_lrs]

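# Warmup from scratch: train once on each k-fold split (train_k20 ... train_k99) using the warmup scheduler.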
if epoch == 0:
    for i in [20, 40, 60, 80, 99]:
        epoch += 1
        train_loader = get_split_dataloader(f'{args.dataset}/train_k{i}.csv',
                                            width=width,
                                            batch_size=batch_size,
                                            transform=transform,
                                            num_workers=num_workers)
        train(model,
              train_loader,
              optimizer=optimizer,
              epoch=epoch,
              scheduler=scheduler_warmup,
              half=args.half)
Example #2
def main():
    global args, MODELS_DIR
    print args

    if args.dbg:
        MODELS_DIR = join(MODELS_DIR, 'dbg')

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    # if torch.cuda.is_available() and not args.cuda:
    #     print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    print 'CudNN:', torch.backends.cudnn.version()
    print 'Run on {} GPUs'.format(torch.cuda.device_count())
    cudnn.benchmark = True

    is_sobel = args.arch.endswith('Sobel')
    print 'is_sobel', is_sobel

    print("=> creating model '{}'".format(args.arch))
    model = models.__dict__[args.arch](
        num_classes=args.num_clusters if args.unsupervised else 1000,
        dropout7_prob=args.dropout7_prob)
    model = torch.nn.DataParallel(model).cuda()
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(),
                                args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

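    # Compose the experiment name from the architecture, LR, supervision mode and ImageNet version;
    # unsupervised runs also encode the clustering hyper-parameters.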
    experiment = "{}_lr{}_{}{}".format(
        args.arch, args.lr, 'unsup' if args.unsupervised else 'labels',
        '_v2' if args.imagenet_version == 2 else '')
    if args.unsupervised:
        experiment += '{sobel_norm}_nc{nc}_l{clustering_layer}_rec{rec_epoch}{reset_fc}'.format(
            sobel_norm='_normed' if args.sobel_normalized else '',
            nc=args.num_clusters,
            clustering_layer=args.clustering_layer,
            rec_epoch=args.recluster_epoch,
            reset_fc='_reset-fc' if args.reset_fc else '')

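    # Resolve the output dir and try to resume: look for checkpoint.pth.tar (or model_best.pth.tar
    # with args.from_best), fall back to model_best.pth.tar and then model_best_nmi.pth.tar,
    # and offer to clear the dir if nothing is found.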
    checkpoint = None
    if args.output_dir is None:
        args.output_dir = join(MODELS_DIR, experiment + '_' + args.exp_suffix)

    if args.output_dir is not None and os.path.exists(args.output_dir):
        ckpt_path = join(
            args.output_dir, 'checkpoint.pth.tar'
            if not args.from_best else 'model_best.pth.tar')
        if not os.path.isfile(ckpt_path):
            print "=> no checkpoint found at '{}'\nUsing model_best.pth.tar".format(
                ckpt_path)
            ckpt_path = join(args.output_dir, 'model_best.pth.tar')

        if os.path.isfile(ckpt_path):
            print("=> loading checkpoint '{}'".format(ckpt_path))
            checkpoint = torch.load(ckpt_path)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                ckpt_path, checkpoint['epoch']))
        else:
            print "=> no checkpoint found at '{}'\nUsing model_best_nmi.pth.tar".format(
                ckpt_path)
            ckpt_path = join(args.output_dir, 'model_best_nmi.pth.tar')

        if os.path.isfile(ckpt_path):
            print("=> loading checkpoint '{}'".format(ckpt_path))
            checkpoint = torch.load(ckpt_path)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                ckpt_path, checkpoint['epoch']))
        else:
            print "=> no checkpoint found at '{}'".format(ckpt_path)
            ans = None
            while ans != 'y' and ans != 'n':
                ans = raw_input('Clear the dir {}? [y/n] '.format(
                    args.output_dir)).lower()
            if ans.lower() == 'y':
                shutil.rmtree(args.output_dir)
            else:
                print 'Just write in the same dir.'
                # raise IOError("=> no checkpoint found at '{}'".format(ckpt_path))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    print 'Output dir:', args.output_dir

    start_epoch = 0
    best_score = 0
    best_nmi = 0
    if checkpoint is not None:
        start_epoch = checkpoint['epoch']
        if 'best_score' in checkpoint:
            best_score = checkpoint['best_score']
        else:
            print 'WARNING! No "best_score" found in checkpoint!'
            best_score = 0
        if 'nmi' in checkpoint:
            print 'Current NMI/GT:', checkpoint['nmi']
        if 'best_nmi' in checkpoint:
            best_nmi = checkpoint['best_nmi']
            print 'Best NMI/GT:', best_nmi
        print 'Best score:', best_score
        if 'cur_score' in checkpoint:
            print 'Current score:', checkpoint['cur_score']
        model.load_state_dict(checkpoint['state_dict'])
        print 'state dict loaded'
        optimizer.load_state_dict(checkpoint['optimizer'])
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.lr
            param_group['initial_lr'] = args.lr
    logger = SummaryWriter(log_dir=args.output_dir)

    ### Data loading ###
    num_gt_classes = 1000
    split_dirs = {
        'train':
        join(args.data,
             'train' if args.imagenet_version == 1 else 'train_256'),
        'val':
        join(
            args.data, 'val' if args.imagenet_version == 1 else 'val_256'
        )  # we get lower accuracy with val_256, probably because of jpeg compression
    }
    dataset_indices = dict()
    for key in ['train', 'val']:
        index_path = join(args.data,
                          os.path.basename(split_dirs[key]) + '_index.json')

        if os.path.exists(index_path):
            with open(index_path) as json_file:
                dataset_indices[key] = json.load(json_file)
        else:
            print 'Indexing ' + key
            dataset_indices[key] = index_imagenet(split_dirs[key], index_path)

    assert dataset_indices['train']['class_to_idx'] == \
           dataset_indices['val']['class_to_idx']
    if args.dbg:
        max_images = 1000
        print 'DBG: WARNING! Truncating train dataset to {} images'.format(
            max_images)
        dataset_indices['train']['samples'] = dataset_indices['train'][
            'samples'][:max_images]
        dataset_indices['val']['samples'] = dataset_indices['val'][
            'samples'][:max_images]

    num_workers = args.workers  # if args.unsupervised else max(1, args.workers / 2)

    print '[TRAIN]...'
    if args.unsupervised:
        train_loader_gt = create_data_loader(
            split_dirs['train'],
            dataset_indices['train'],
            is_sobel,
            sobel_normalized=args.sobel_normalized,
            aug='random_crop_flip',
            shuffle='shuffle' if not args.fast_dataflow else 'shuffle_buffer',
            num_workers=num_workers,
            use_fast_dataflow=args.fast_dataflow,
            buffer_size=args.buffer_size)
        eval_gt_aug = '10_crop'
        val_loader_gt = create_data_loader(
            split_dirs['val'],
            dataset_indices['val'],
            is_sobel,
            sobel_normalized=args.sobel_normalized,
            aug=eval_gt_aug,
            batch_size=26,  # reduced batch size: 10-crop evaluation needs far more memory
            shuffle='shuffle',
            num_workers=num_workers,
            use_fast_dataflow=False,
            buffer_size=args.buffer_size)
    else:
        train_loader = create_data_loader(
            split_dirs['train'],
            dataset_indices['train'],
            is_sobel,
            sobel_normalized=args.sobel_normalized,
            aug='random_crop_flip',
            shuffle='shuffle' if not args.fast_dataflow else 'shuffle_buffer',
            num_workers=num_workers,
            use_fast_dataflow=args.fast_dataflow,
            buffer_size=args.buffer_size)
        print '[VAL]...'
        # with GT labels!
        val_loader = create_data_loader(
            split_dirs['val'],
            dataset_indices['val'],
            is_sobel,
            sobel_normalized=args.sobel_normalized,
            aug='central_crop',
            batch_size=args.batch_size,
            shuffle='shuffle' if not args.fast_dataflow else None,
            num_workers=num_workers,
            use_fast_dataflow=args.fast_dataflow,
            buffer_size=args.buffer_size)
    ###############################################################################

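    # Choose the LR schedule: fixed multi-step decays, or custom cyclic / step-with-minimum schedules
    # wrapped in LambdaLR (base_lrs set to 1.0 so the lambda returns absolute learning rates).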
    # StepLR(optimizer, step_size=args.decay_step, gamma=args.decay_gamma)
    if args.scheduler == 'multi_step':
        scheduler = MultiStepLR(optimizer,
                                milestones=[30, 60, 80],
                                gamma=args.decay_gamma)
    elif args.scheduler == 'multi_step2':
        scheduler = MultiStepLR(optimizer,
                                milestones=[50, 100],
                                gamma=args.decay_gamma)
    elif args.scheduler == 'cyclic':
        print 'Using Cyclic LR!'
        cyclic_lr = CyclicLr(start_epoch if args.reset_lr else 0,
                             init_lr=args.lr,
                             num_epochs_per_cycle=args.cycle,
                             epochs_pro_decay=args.decay_step,
                             lr_decay_factor=args.decay_gamma)
        scheduler = LambdaLR(optimizer, lr_lambda=cyclic_lr)
        scheduler.base_lrs = list(
            map(lambda group: 1.0, optimizer.param_groups))
    elif args.scheduler == 'step':
        step_lr = StepMinLr(start_epoch if args.reset_lr else 0,
                            init_lr=args.lr,
                            epochs_pro_decay=args.decay_step,
                            lr_decay_factor=args.decay_gamma,
                            min_lr=args.min_lr)

        scheduler = LambdaLR(optimizer, lr_lambda=step_lr)
        scheduler.base_lrs = list(
            map(lambda group: 1.0, optimizer.param_groups))
    else:
        assert False, 'wrong scheduler: ' + args.scheduler

    print 'scheduler.base_lrs=', scheduler.base_lrs
    logger.add_scalar('data/batch_size', args.batch_size, start_epoch)

    save_epoch = 50
    if not args.unsupervised:
        validate_epoch = 1
    else:
        validate_epoch = 50
        labels_holder = {
        }  # utility container to save labels from the previous clustering step

    last_lr = 100500
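    # Main training loop: re-cluster when scheduled, step the LR scheduler, train one epoch,
    # validate periodically, and checkpoint.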
    for epoch in range(start_epoch, args.epochs):
        nmi_gt = None
        if epoch == start_epoch:
            if not args.unsupervised:
                validate(val_loader,
                         model,
                         criterion,
                         epoch - 1,
                         logger=logger)
            # elif start_epoch == 0:
            #     print 'validate_gt_linear'
            #     validate_gt_linear(train_loader_gt, val_loader_gt, num_gt_classes,
            #                        model, args.eval_layer, criterion, epoch - 1, lr=0.01,
            #                        num_train_epochs=2,
            #                        logger=logger, tag='val_gt_{}_{}'.format(args.eval_layer, eval_gt_aug))

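        # Unsupervised mode: periodically re-cluster to generate fresh pseudo-labels, optionally
        # reset the fc8 head, and dump the clustering history to labels_holder.json.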
        if args.unsupervised and (epoch == start_epoch
                                  or epoch % args.recluster_epoch == 0):
            train_loader, nmi_gt = unsupervised_clustering_step(
                epoch, model, is_sobel, args.sobel_normalized, split_dirs,
                dataset_indices, num_workers, labels_holder, logger,
                args.fast_dataflow)
            if args.reset_fc:
                model.module.reset_fc8()
            try:
                with open(join(args.output_dir, 'labels_holder.json'),
                          'w') as f:
                    for k in labels_holder.keys():
                        labels_holder[k] = np.asarray(
                            labels_holder[k]).tolist()
                    json.dump(labels_holder, f)
            except Exception as e:
                print e

        scheduler.step(epoch=epoch)
        if last_lr != scheduler.get_lr()[0]:
            last_lr = scheduler.get_lr()[0]
            print 'LR := {}'.format(last_lr)
        logger.add_scalar('data/lr', scheduler.get_lr()[0], epoch)
        logger.add_scalar('data/v', args.imagenet_version, epoch)
        logger.add_scalar('data/weight_decay', args.weight_decay, epoch)
        logger.add_scalar('data/dropout7_prob', args.dropout7_prob, epoch)

        top1_avg, top5_avg, loss_avg = \
            train(train_loader, model, criterion, optimizer,
                  epoch, args.epochs,
                  log_iter=100, logger=logger)

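        # Evaluate every validate_epoch epochs: plain validation when training with labels, or a
        # linear probe on a frozen layer (validate_gt_linear) in the unsupervised case; track the
        # best score / NMI for checkpointing.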
        if (epoch + 1) % validate_epoch == 0:
            # evaluate on validation set
            if not args.unsupervised:
                score = validate(val_loader,
                                 model,
                                 criterion,
                                 epoch,
                                 logger=logger)
            else:
                score = validate_gt_linear(
                    train_loader_gt,
                    val_loader_gt,
                    num_gt_classes,
                    model,
                    args.eval_layer,
                    criterion,
                    epoch,
                    lr=0.01,
                    num_train_epochs=args.epochs_train_linear,
                    logger=logger,
                    tag='val_gt_{}_{}'.format(args.eval_layer, eval_gt_aug))

            # remember best prec@1 and save checkpoint
            is_best = score > best_score
            best_score = max(score, best_score)
            best_ckpt_suffix = ''
        else:
            score = None
            if nmi_gt is not None and nmi_gt > best_nmi:
                best_nmi = nmi_gt
                best_ckpt_suffix = '_nmi'
                is_best = True
            else:
                is_best = False
                best_ckpt_suffix = ''

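        # Save a rolling checkpoint.pth.tar every epoch and a numbered snapshot every save_epoch
        # epochs; save_checkpoint additionally keeps the best model (suffix '_nmi' when best by NMI).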
        if (epoch + 1) % save_epoch == 0:
            filepath = join(args.output_dir,
                            'checkpoint-{:05d}.pth.tar'.format(epoch + 1))
        else:
            filepath = join(args.output_dir, 'checkpoint.pth.tar')
        save_dict = {
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_score': best_score,
            'top1_avg_accuracy_train': top1_avg,
            'optimizer': optimizer.state_dict(),
        }
        if nmi_gt is not None:
            save_dict['nmi'] = nmi_gt
            save_dict['best_nmi'] = best_nmi
        if score is not None:
            save_dict['cur_score'] = score
        save_checkpoint(save_dict,
                        is_best=is_best,
                        filepath=filepath,
                        best_suffix=best_ckpt_suffix)