Example #1
        args.epochs = math.ceil(args.steps * args.batch_size /
                                len(train_loader))
    if args.epochs is None:
        args.epochs = args.steps
    for epoch in range(1, args.epochs + 1):
        epoch_start_time = time.time()
        cumulative_steps, cumulative_time, done = train(
            cumulative_steps, cumulative_time)
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s'.format(
            epoch, (time.time() - epoch_start_time)))
        if args.enable_gavel_iterator:
            if train_loader.done:
                break
            elif done:
                train_loader.complete()
                break
        elif done:
            break
        print('-' * 89)
    checkpoint_path = os.path.join(args.checkpoint_dir, 'model.chkpt')
    if not args.distributed or args.rank == 0:
        if args.distributed:
            state = {'model': model.module}
        else:
            state = {'model': model}
        if args.enable_gavel_iterator:
            train_loader.save_checkpoint(state, checkpoint_path)
        else:
            save_checkpoint(state, checkpoint_path)
except KeyboardInterrupt:
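
The snippet above routes its final checkpoint through the GavelIterator when it is enabled. The helper below is a minimal sketch, not part of the project, that distills that convention under stated assumptions: only rank 0 writes the checkpoint, DistributedDataParallel is unwrapped first, and the write goes through train_loader.save_checkpoint when the Gavel iterator is in use. save_checkpoint(state, path) stands in for the project's own helper and is passed as a parameter to keep the sketch self-contained.

import os

def save_final_checkpoint(model, args, train_loader, save_checkpoint):
    # `train_loader` is assumed to be a GavelIterator when
    # args.enable_gavel_iterator is True; `save_checkpoint(state, path)`
    # stands in for the project's own helper.
    checkpoint_path = os.path.join(args.checkpoint_dir, 'model.chkpt')
    if args.distributed and args.rank != 0:
        return  # only one worker persists the model
    # Unwrap DistributedDataParallel so the checkpoint holds the bare module.
    state = {'model': model.module if args.distributed else model}
    if args.enable_gavel_iterator:
        train_loader.save_checkpoint(state, checkpoint_path)
    else:
        save_checkpoint(state, checkpoint_path)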
Example #2
File: main.py  Project: xieydd/gavel
def main():
    global args, best_acc1, total_minibatches, total_elapsed_time
    args = parser.parse_args()
    torch.cuda.set_device(args.local_rank)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    args.distributed = False
    if args.master_addr is not None:
        args.distributed = True
        os.environ['MASTER_ADDR'] = args.master_addr
        os.environ['MASTER_PORT'] = str(args.master_port)
        dist.init_process_group(backend=args.dist_backend,
                                init_method=args.dist_url,
                                world_size=args.world_size,
                                rank=args.rank)

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    model = model.cuda()
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.local_rank],
                output_device=args.local_rank)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    if args.enable_gavel_iterator:
        train_loader = GavelIterator(train_loader, args.checkpoint_dir,
                                     load_checkpoint, save_checkpoint)

    # Load from checkpoint.
    if not os.path.isdir(args.checkpoint_dir):
        os.mkdir(args.checkpoint_dir)
    checkpoint_path = os.path.join(args.checkpoint_dir, 'model.chkpt')
    if os.path.exists(checkpoint_path):
        if args.enable_gavel_iterator:
            checkpoint = train_loader.load_checkpoint(args, checkpoint_path)
        else:
            checkpoint = load_checkpoint(args, checkpoint_path)

        if checkpoint is not None:
            args.start_epoch = checkpoint['epoch']
            # best_acc1 = checkpoint['best_acc1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(checkpoint_path, checkpoint['epoch']))

    if args.num_minibatches is not None:
        args.epochs = math.ceil(float(args.num_minibatches) *
                                args.batch_size / len(train_loader))

    # Keep `epoch` defined for the final checkpoint below even if the loop
    # never executes.
    epoch = args.start_epoch
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        num_minibatches, elapsed_time, finished_epoch = \
            train(train_loader, model, criterion, optimizer,
                  epoch, total_minibatches,
                  max_minibatches=args.num_minibatches,
                  total_elapsed_time=total_elapsed_time,
                  max_duration=args.max_duration)
        total_minibatches += num_minibatches
        total_elapsed_time += elapsed_time

        if args.enable_gavel_iterator and train_loader.done:
            break
        elif (args.num_minibatches is not None and
              total_minibatches >= args.num_minibatches):
            if args.enable_gavel_iterator:
                train_loader.complete()
            break
        elif (args.max_duration is not None and
              total_elapsed_time >= args.max_duration):
            if args.enable_gavel_iterator:
                train_loader.complete()
            break

        # evaluate on validation set
        # acc1 = validate(val_loader, model, criterion)

        # remember best acc@1 and save checkpoint
        #best_acc1 = max(acc1, best_acc1)
    if not args.distributed or args.rank == 0:
        state = {
            'epoch': epoch,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            # 'best_acc1': best_acc1,
            'optimizer': optimizer.state_dict(),
        }
        if args.enable_gavel_iterator:
            train_loader.save_checkpoint(state, checkpoint_path)
        else:
            save_checkpoint(state, checkpoint_path)
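
Taken together, main() follows a compact integration pattern: wrap the DataLoader in a GavelIterator, resume through the iterator if a checkpoint exists, stop early when the scheduler marks the job done, and call complete() when the job exhausts its own minibatch budget. The sketch below restates that control flow under stated assumptions: GavelIterator, load_checkpoint, and save_checkpoint are the project's own class and helpers, and train_one_epoch is a hypothetical stand-in for the train() call used above, assumed to return the number of minibatches it processed.

import os

def run_with_gavel(train_loader, model, optimizer, args, GavelIterator,
                   load_checkpoint, save_checkpoint, train_one_epoch):
    checkpoint_path = os.path.join(args.checkpoint_dir, 'model.chkpt')
    train_loader = GavelIterator(train_loader, args.checkpoint_dir,
                                 load_checkpoint, save_checkpoint)

    # Resume through the iterator so the scheduler observes the restore.
    if os.path.exists(checkpoint_path):
        checkpoint = train_loader.load_checkpoint(args, checkpoint_path)
        if checkpoint is not None:
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])

    total_minibatches = 0
    epoch = args.start_epoch  # keep `epoch` defined if the loop never runs
    for epoch in range(args.start_epoch, args.epochs):
        total_minibatches += train_one_epoch(train_loader, model, optimizer,
                                             epoch)
        if train_loader.done:
            break  # the scheduler preempted or finished the job
        if (args.num_minibatches is not None and
                total_minibatches >= args.num_minibatches):
            train_loader.complete()  # the job finished on its own terms
            break

    state = {'epoch': epoch,
             'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict()}
    train_loader.save_checkpoint(state, checkpoint_path)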