Example 1
def main():
    global args, sample_images_train_tensors, sample_images_val_tensors
    args = parser.parse_args()
    print('args.world_size: ', args.world_size)
    print('args.dist_backend: ', args.dist_backend)
    print('args.rank: ', args.rank)

    # for more info on distributed PyTorch, see https://pytorch.org/tutorials/intermediate/dist_tuto.html
    args.distributed = args.world_size >= 2
    print('is distributed: {}'.format(args.distributed))
    if args.distributed:
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                world_size=args.world_size, rank=args.rank)
        print('dist.init_process_group() finished.')

    # data sets and loaders
    dset_train = SpaceNetDataset(data_path_train, split_tags, transform=T.Compose([ToTensor()]))
    dset_val = SpaceNetDataset(data_path_val, split_tags, transform=T.Compose([ToTensor()]))
    logging.info('Training set size: {}, validation set size: {}'.format(
        len(dset_train), len(dset_val)))

    # instantiate these data loaders with shuffle=True so the sample images drawn from them are randomly selected
    loader_train = DataLoader(dset_train, batch_size=train_batch_size, shuffle=True,
                              num_workers=num_workers)  # shuffle True to reshuffle at every epoch

    loader_val = DataLoader(dset_val, batch_size=val_batch_size, shuffle=True,
                            num_workers=num_workers)

    # get one batch of sample images that are used to visualize the training progress throughout this run
    sample_images_train, sample_images_train_tensors = get_sample_images(loader_train, which_set='train')
    sample_images_val, sample_images_val_tensors = get_sample_images(loader_val, which_set='val')

    if args.distributed:
        # re-instantiate the training data loader to make distributed training possible
        train_batch_size_dist = train_batch_size * args.world_size
        logging.info('Using train_batch_size_dist {}.'.format(train_batch_size_dist))
        train_sampler = torch.utils.data.BatchSampler(
            torch.utils.data.distributed.DistributedSampler(dset_train),
            batch_size=train_batch_size_dist, drop_last=False)
        # TODO https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
        # check if need num_replicas and rank
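        # Note: once dist.init_process_group() has run (as above), DistributedSampler
        # reads num_replicas and rank from the default process group, so they
        # usually do not need to be passed explicitly.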
        print('train_sampler created successfully.')
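        # batch_sampler is mutually exclusive with batch_size, shuffle, sampler and
        # drop_last, so none of those are passed to this DataLoader.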
        loader_train = DataLoader(dset_train, num_workers=num_workers,
                                  pin_memory=True, batch_sampler=train_sampler)

        loader_val = DataLoader(dset_val, batch_size=val_batch_size, shuffle=False,
                                num_workers=num_workers, pin_memory=True)
        print('both data loaders created successfully.')

    # checkpoint dir
    checkpoint_dir = out_checkpoint_dir

    logger_train = Logger('{}/train'.format(tensorboard_path))
    logger_val = Logger('{}/val'.format(tensorboard_path))
    log_sample_img_gt(sample_images_train, sample_images_val, logger_train, logger_val)
    logging.info('Logged ground truth image samples')

    num_classes = 3

    # larger model
    if model_choice == 'unet':
        model = Unet(feature_scale=feature_scale, n_classes=num_classes, is_deconv=True, in_channels=3, is_batchnorm=True)
    # year 2 best solution XD_XD's model, as the baseline model
    elif model_choice == 'unet_baseline':
        model = UnetBaseline(feature_scale=feature_scale, n_classes=num_classes, is_deconv=True, in_channels=3, is_batchnorm=True)
    else:
        sys.exit('Invalid model_choice {}, choose unet_baseline or unet'.format(model_choice))
    print('model instantiated.')

    if not args.distributed:
        model = model.to(device=device, dtype=dtype)  # move the model parameters to target device
        #model = torch.nn.DataParallel(model).cuda() # Batch AI example
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
        print('torch.nn.parallel.DistributedDataParallel() ran.')

    criterion = nn.CrossEntropyLoss(weight=loss_weights).to(device=device, dtype=dtype)

    # can also use Nesterov momentum in optim.SGD
    # optimizer = optim.SGD(model.parameters(), lr=learning_rate,
    #                     momentum=0.9, nesterov=True)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # resume from a checkpoint if provided
    starting_epoch = 0
    best_acc = 0.0
    if os.path.isfile(starting_checkpoint_path):
        logging.info('Loading checkpoint from {0}'.format(starting_checkpoint_path))
        checkpoint = torch.load(starting_checkpoint_path)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        starting_epoch = checkpoint['epoch']
        best_acc = checkpoint.get('best_acc', 0.0)
    else:
        logging.info('No valid checkpoint provided. Training from scratch...')
        model.apply(weights_init)

    # run training or evaluation
    if evaluate_only:
        val_loss, val_acc = evaluate(loader_val, model, criterion)
        print('Evaluated on val set, loss is {}, accuracy is {}'.format(val_loss, val_acc))
        return

    step = starting_epoch * len(dset_train)

    for epoch in range(starting_epoch, total_epochs):
        logging.info('Epoch {} of {}'.format(epoch, total_epochs))

        # train for one epoch
        step = train(loader_train, model, criterion, optimizer, epoch, step, logger_train)

        # evaluate on val set
        logging.info('Evaluating model on the val set at the end of epoch {}...'.format(epoch))
        val_loss, val_acc = evaluate(loader_val, model, criterion)
        logging.info('\nEpoch {}, step {}, val loss is {}, val accuracy is {}\n'.format(epoch, step, val_loss, val_acc))
        logger_val.scalar_summary('val_loss', val_loss, step + 1)
        logger_val.scalar_summary('val_acc', val_acc, step + 1)
        # log the val images too

        # record the best accuracy; save checkpoint for every epoch
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)

        checkpoint_path = os.path.join(checkpoint_dir,
                                       'checkpoint_epoch{}_{}.pth.tar'.format(epoch, strftime("%Y-%m-%d-%H-%M-%S", localtime())))
        logging.info(
            'Saving to checkpoint file at {}. Is it the highest accuracy checkpoint so far: {}'.format(
                checkpoint_path, str(is_best)))
        save_checkpoint({
            'epoch': epoch + 1,  # saved checkpoints are numbered starting from 1
            'arch': model_choice,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_acc': best_acc
        }, is_best, checkpoint_path, checkpoint_dir)
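
This example reads world_size, rank, dist_backend and dist_url from a module-level argparse parser that is not shown. A minimal sketch of what that setup might look like (the flag names and defaults below are assumptions, not taken from the original script):

import argparse

# Hypothetical parser matching the attributes main() reads above
# (args.world_size, args.rank, args.dist_backend, args.dist_url).
parser = argparse.ArgumentParser(description='SpaceNet segmentation training')
parser.add_argument('--world-size', type=int, default=1,
                    help='number of distributed processes')
parser.add_argument('--rank', type=int, default=0,
                    help='rank of this process among all processes')
parser.add_argument('--dist-backend', default='nccl',
                    help='backend passed to dist.init_process_group()')
parser.add_argument('--dist-url', default='env://',
                    help='init_method URL for dist.init_process_group()')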
Example 2
def main():
    num_classes = 3

    # create checkpoint dir
    checkpoint_dir = 'checkpoints/{}'.format(experiment_name)
    os.makedirs(checkpoint_dir, exist_ok=True)

    logger_train = Logger('logs/{}/train'.format(experiment_name))
    logger_val = Logger('logs/{}/val'.format(experiment_name))
    log_sample_img_gt(sample_images_train, sample_images_val, logger_train,
                      logger_val)
    logging.info('Logged ground truth image samples')

    # larger model
    if model_choice == 'unet':
        model = Unet(feature_scale=feature_scale,
                     n_classes=num_classes,
                     is_deconv=True,
                     in_channels=3,
                     is_batchnorm=True)
    # year 2 best solution XD_XD's model, as the baseline model
    elif model_choice == 'unet_baseline':
        model = UnetBaseline(feature_scale=feature_scale,
                             n_classes=num_classes,
                             is_deconv=True,
                             in_channels=3,
                             is_batchnorm=True)
    else:
        sys.exit(
            'Invalid model_choice {}, choose unet_baseline or unet'.format(
                model_choice))

    model = model.to(device=device,
                     dtype=dtype)  # move the model parameters to CPU/GPU

    criterion = nn.CrossEntropyLoss(weight=loss_weights).to(device=device,
                                                            dtype=dtype)

    # can also use Nesterov momentum in optim.SGD
    # optimizer = optim.SGD(model.parameters(), lr=learning_rate,
    #                     momentum=0.9, nesterov=True)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # resume from a checkpoint if provided
    starting_epoch = 0
    best_acc = 0.0

    if os.path.isfile(starting_checkpoint_path):
        logging.info(
            'Loading checkpoint from {0}'.format(starting_checkpoint_path))
        checkpoint = torch.load(starting_checkpoint_path)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        starting_epoch = checkpoint['epoch']
        best_acc = checkpoint.get('best_acc', 0.0)
    else:
        logging.info(
            'No valid checkpoint provided. Training from scratch...')
        model.apply(weights_init)

    if evaluate_only:
        val_loss, val_acc = evaluate(loader_val, model, criterion)
        print('Evaluated on val set, loss is {}, accuracy is {}'.format(
            val_loss, val_acc))
        return

    step = starting_epoch * len(dset_train)

    for epoch in range(starting_epoch, total_epochs):
        logging.info('Epoch {} of {}'.format(epoch, total_epochs))

        # train for one epoch
        step = train(loader_train, model, criterion, optimizer, epoch, step,
                     logger_train)

        # evaluate on val set
        logging.info(
            'Evaluating model on the val set at the end of epoch {}...'.format(
                epoch))
        val_loss, val_acc = evaluate(loader_val, model, criterion)
        logging.info('\nEpoch {}, step {}, val loss is {}, val accuracy is {}\n'.format(
            epoch, step, val_loss, val_acc))
        logger_val.scalar_summary('val_loss', val_loss, step + 1)
        logger_val.scalar_summary('val_acc', val_acc, step + 1)
        # log the val images too

        # record the best accuracy; save checkpoint for every epoch
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)

        checkpoint_path = os.path.join(
            checkpoint_dir, 'checkpoint_epoch{}_{}.pth.tar'.format(
                epoch, strftime("%Y-%m-%d-%H-%M-%S", localtime())))
        logging.info(
            'Saving to checkpoint file at {}. Is it the highest accuracy checkpoint so far: {}'
            .format(checkpoint_path, str(is_best)))
        save_checkpoint(
            {
                'epoch': epoch + 1,  # saved checkpoints are numbered starting from 1
                'arch': model_choice,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_acc': best_acc
            },
            is_best,
            checkpoint_path,
            checkpoint_dir)
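
Both examples above call a save_checkpoint helper that is not shown. A minimal sketch of a conventional implementation for the (state, is_best, checkpoint_path, checkpoint_dir) signature used here; the model_best.pth.tar filename is an assumption:

import os
import shutil
import torch

def save_checkpoint(state, is_best, checkpoint_path, checkpoint_dir):
    # Hypothetical sketch: write the checkpoint, and keep a separate copy of the
    # best-scoring one so it is easy to locate later.
    torch.save(state, checkpoint_path)
    if is_best:
        shutil.copyfile(checkpoint_path,
                        os.path.join(checkpoint_dir, 'model_best.pth.tar'))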
Example 3
def main():

    out_dir = './outputs' if config.out_dir is None else config.out_dir

    # if running locally, copy current version of scripts and config to output folder as a record
    if out_dir != './outputs':
        scripts_copy_dir = os.path.join(out_dir, 'repo_copy')
        cwd = os.getcwd()
        logging.info(f'cwd is {cwd}')
        if 'scripts' not in cwd:
            cwd = os.path.join(cwd, 'scripts')
        copytree(cwd, scripts_copy_dir)  # scripts_copy_dir cannot already exist
        logging.info(f'Copied over scripts to output dir at {scripts_copy_dir}')

    # create checkpoint dir
    checkpoint_dir = os.path.join(out_dir, config.experiment_name, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)

    # model
    model = config.model
    model = model.to(device=device)  # move the model parameters to CPU/GPU

    if config.loss_weights is not None:
        assert isinstance(config.loss_weights, torch.Tensor), \
            'config.loss_weights needs to be of Tensor type'
        assert len(config.loss_weights) == config.num_classes, \
            f'config.loss_weights has length {len(config.loss_weights)} but needs to equal num_classes'
    criterion = nn.CrossEntropyLoss(weight=config.loss_weights).to(device=device)

    optimizer = optim.Adam(model.parameters(), lr=config.init_learning_rate)

    # resume from a checkpoint if provided
    starting_checkpoint_path = config.starting_checkpoint_path
    if starting_checkpoint_path and os.path.isfile(starting_checkpoint_path):
        logging.info('Loading checkpoint from {}'.format(starting_checkpoint_path))
        checkpoint = torch.load(starting_checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])

        # don't load the optimizer settings so that a newly specified lr can take effect
        # optimizer.load_state_dict(checkpoint['optimizer'])

        starting_epoch = checkpoint['epoch']  # we incremented epoch before saving it, so can just start here
        step = checkpoint.get('step', 0)
        best_acc = checkpoint.get('best_acc', 0.0)
        logging.info(f'Loaded checkpoint, starting epoch is {starting_epoch}, step is {step}, '
                     f'best accuracy is {best_acc}')
    else:
        logging.info('No valid checkpoint provided. Training from scratch...')
        model.apply(weights_init)
        starting_epoch = 0
        best_acc = 0.0
        step = 0

    # data sets and loaders, which will be added to the global scope for easy access in other functions
    global dset_train, loader_train, dset_val, loader_val

    dset_train = config.dset_train
    loader_train = config.loader_train

    dset_val = config.dset_val
    loader_val = config.loader_val

    logging.info('Getting sample chips from val and train set...')
    samples_val = get_sample_images(which_set='val')
    samples_train = get_sample_images(which_set='train')

    # logging
    # run = Run.get_context()
    aml_run = None
    logger_train = Logger('train', config.log_dir, config.batch_size, aml_run)
    logger_val = Logger('val', config.log_dir, config.batch_size, aml_run)
    log_sample_img_gt(logger_train, logger_val,
                      samples_train, samples_val)
    logging.info('Logged image samples')

    if config.config_mode == ExperimentConfigMode.EVALUATION:
        val_loss, val_acc = evaluate(loader_val, model, criterion)
        logging.info(f'Evaluated on val set, loss is {val_loss}, accuracy is {val_acc}')
        return

    for epoch in range(starting_epoch, config.total_epochs):
        logging.info(f'\nEpoch {epoch} of {config.total_epochs}')

        # train for one epoch
        # we need the `step` concept for TensorBoard logging only
        train_start_time = datetime.now()
        step = train(loader_train, model, criterion, optimizer, epoch, step, logger_train)
        train_duration = datetime.now() - train_start_time

        # evaluate on val set
        logging.info('Evaluating model on the val set at the end of epoch {}...'.format(epoch))

        eval_start_time = datetime.now()
        val_loss, val_acc = evaluate(loader_val, model, criterion)
        eval_duration = datetime.now() - eval_start_time

        logging.info(f'\nEpoch {epoch}, step {step}, val loss is {val_loss}, val accuracy is {val_acc}\n')
        logger_val.scalar_summary('val_loss', val_loss, step)
        logger_val.scalar_summary('val_acc', val_acc, step)

        # visualize results on both train and val images
        visualize_result_on_samples(model, samples_train['chip'], logger_train, step, split='train')
        visualize_result_on_samples(model, samples_val['chip'], logger_val, step, split='val')

        # log values and gradients of the parameters (histogram summary)
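        # (assumes train() ran a backward pass so every parameter has a populated
        # .grad; value.grad would be None for parameters that never received gradients)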
        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            logger_train.histo_summary(tag, value.data.cpu().numpy(), step)
            logger_train.histo_summary(tag + '/grad', value.grad.data.cpu().numpy(), step)

        # record the best accuracy; save checkpoint for every epoch
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)

        logging.info(
            f'Iterated through {step * config.batch_size} examples. Saved checkpoint for epoch {epoch}. '
            f'Is it the highest accuracy checkpoint so far: {is_best}\n')

        save_checkpoint({
            # add 1 so when we restart from it, we can just read it and proceed
            'epoch': epoch + 1,
            'step': step,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'val_acc': val_acc,
            'best_acc': best_acc
        }, is_best, checkpoint_dir)

        # log execution time for this epoch
        logging.info(f'epoch training_wcs duration is {train_duration.total_seconds()} seconds; '
                     f'evaluation duration is {eval_duration.total_seconds()} seconds')
        logger_val.scalar_summary('epoch_duration_train', train_duration.total_seconds(), step)
        logger_val.scalar_summary('epoch_duration_val', eval_duration.total_seconds(), step)
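
All three examples initialize an untrained model with model.apply(weights_init), a function that is not shown. A minimal sketch of a typical layer-wise initializer for this kind of convolutional network; the choice of Kaiming initialization is an assumption:

import torch.nn as nn

def weights_init(m):
    # Hypothetical sketch: applied to every submodule via model.apply(weights_init).
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.kaiming_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.ones_(m.weight)
        nn.init.zeros_(m.bias)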