Пример #1
0
def main(test, s3_data, batch, debug):
    """Train a semantic segmentation FPN model on the CamVid-Tiramisu dataset.

    Args:
        test: if truthy, switch to a small smoke-test configuration (batch
            size 1, no loader workers, 2 epochs, ``sample_pct`` of 0.01).
        s3_data: forwarded to ``download_data``; if truthy, the output
            directory is also synced to ``remote_output_uri`` at the end.
        batch: if truthy, ``run_on_batch(test, debug)`` is called first.
        debug: only forwarded to ``run_on_batch``.
    """
    if batch:
        # NOTE(review): no return follows, so the local run below still
        # executes once run_on_batch returns -- confirm this is intended.
        run_on_batch(test, debug)

    # Setup options
    batch_sz = 8
    num_workers = 4
    num_epochs = 20
    lr = 1e-4
    backbone_arch = 'resnet18'
    sample_pct = 1.0

    if test:
        batch_sz = 1
        num_workers = 0
        num_epochs = 2
        sample_pct = 0.01

    # Setup data
    # The temporary directory is used as a context manager so the downloaded
    # data is removed deterministically, even if training raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        output_dir = local_output_uri
        make_dir(output_dir)

        data_dir = download_data(s3_data, tmp_dir)
        data = get_databunch(data_dir,
                             sample_pct=sample_pct,
                             batch_sz=batch_sz,
                             num_workers=num_workers)
        print(data)
        plot_data(data, output_dir)

        # Setup and train model
        num_classes = data.c
        model = SegmentationFPN(backbone_arch, num_classes)
        metrics = [acc_camvid]
        learn = Learner(data,
                        model,
                        metrics=metrics,
                        loss_func=SegmentationFPN.loss,
                        path=output_dir)
        learn.unfreeze()

        callbacks = [
            SaveModelCallback(learn, monitor='valid_loss'),
            CSVLogger(learn, filename='log'),
        ]

        learn.fit_one_cycle(num_epochs, lr, callbacks=callbacks)

        # Plot predictions and sync
        plot_preds(data, learn, output_dir)

        if s3_data:
            sync_to_dir(output_dir, remote_output_uri)
Пример #2
0
    def train(self):
        """Run the training loop, resuming from a checkpoint if one exists.

        If ``train_state.json`` is present in ``self.output_dir``, the model
        weights are reloaded from ``last_model.pth`` and training continues
        from the epoch after the recorded one. After every epoch the model
        weights, the train state and one row of ``log.csv`` are written. When
        ``self.cfg.output_uri`` starts with ``s3://`` the output directory is
        synced there every ``self.cfg.solver.sync_interval`` epochs.
        """
        last_model_path = join(self.output_dir, 'last_model.pth')
        start_epoch = 0

        log_path = join(self.output_dir, 'log.csv')
        train_state_path = join(self.output_dir, 'train_state.json')
        if isfile(train_state_path):
            print('Resuming from checkpoint: {}\n'.format(last_model_path))
            train_state = file_to_json(train_state_path)
            start_epoch = train_state['epoch'] + 1
            self.model.load_state_dict(
                torch.load(last_model_path, map_location=self.device))

        metric_names = ['precision', 'recall', 'f1', 'accuracy']
        if not isfile(log_path):
            # newline='' lets the csv module control line endings; without it
            # every row is followed by a blank line on Windows.
            with open(log_path, 'w', newline='') as log_file:
                log_writer = csv.writer(log_file)
                row = ['epoch', 'time', 'train_loss'] + metric_names
                log_writer.writerow(row)

        for epoch in range(start_epoch, self.cfg.solver.num_epochs):
            start = time.time()
            train_loss = self.train_epoch(self.databunch.train_dl)
            end = time.time()
            epoch_time = datetime.timedelta(seconds=end - start)
            if self.epoch_scheduler:
                self.epoch_scheduler.step()

            print('----------------------------------------')
            print('epoch: {}'.format(epoch), flush=True)
            print('train loss: {}'.format(train_loss), flush=True)
            print('elapsed: {}'.format(epoch_time), flush=True)

            metrics = self.validate(self.databunch.valid_dl)
            print('validation metrics: {}'.format(metrics), flush=True)

            # Weights are saved before the train state so that a recorded
            # epoch always has a matching last_model.pth to resume from.
            torch.save(self.model.state_dict(), last_model_path)
            train_state = {'epoch': epoch}
            json_to_file(train_state, train_state_path)

            with open(log_path, 'a', newline='') as log_file:
                log_writer = csv.writer(log_file)
                row = [epoch, epoch_time, train_loss]
                row += [metrics[k] for k in metric_names]
                log_writer.writerow(row)

            if (self.cfg.output_uri.startswith('s3://')
                    and ((epoch + 1) % self.cfg.solver.sync_interval == 0)):
                sync_to_dir(self.output_dir, self.cfg.output_uri)
Пример #3
0
 def on_epoch_end(self, **kwargs):
     """Sync ``self.from_dir`` to ``self.to_uri`` at the configured interval.

     The sync (with ``delete=True``) runs only when ``kwargs['epoch'] + 1``
     is a multiple of ``self.sync_interval``; otherwise nothing happens.
     """
     epochs_done = kwargs['epoch'] + 1
     if epochs_done % self.sync_interval != 0:
         return
     sync_to_dir(self.from_dir, self.to_uri, delete=True)
Пример #4
0
def train_loop(cfg, databunch, model, opt, device, output_dir):
    """Train ``model`` up to ``cfg.solver.num_epochs`` epochs with checkpoints.

    If ``train_state.json`` exists in ``output_dir``, the epoch counter and
    ``best_metric`` are restored from it and the weights are reloaded from
    ``last_model.pth``. After every epoch the weights, the train state and
    one row of ``log.csv`` are written. When ``cfg.output_uri`` starts with
    ``s3://`` the output directory is synced there every
    ``cfg.solver.sync_interval`` epochs.

    Args:
        cfg: configuration object; ``solver.num_epochs``,
            ``solver.sync_interval`` and ``output_uri`` are read here.
        databunch: provides ``label_names``, ``train_dl`` and ``valid_dl``.
        model: module to train; must expose ``subloss_names``.
        opt: optimizer, forwarded to ``build_scheduler`` and ``train_epoch``.
        device: device used for loading the checkpoint and for training.
        output_dir: local directory for checkpoints and logs.
    """
    best_model_path = join(output_dir, 'best_model.pth')
    last_model_path = join(output_dir, 'last_model.pth')
    num_labels = len(databunch.label_names)

    best_metric = -1.0
    start_epoch = 0
    train_state_path = join(output_dir, 'train_state.json')
    log_path = join(output_dir, 'log.csv')
    if isfile(train_state_path):
        print('Resuming from checkpoint: {}\n'.format(last_model_path))
        train_state = file_to_json(train_state_path)
        start_epoch = train_state['epoch'] + 1
        best_metric = train_state['best_metric']
        model.load_state_dict(torch.load(last_model_path, map_location=device))

    if not isfile(log_path):
        # newline='' lets the csv module control line endings; without it
        # every row is followed by a blank line on Windows.
        with open(log_path, 'w', newline='') as log_file:
            log_writer = csv.writer(log_file)
            row = ['epoch'] + ['map50', 'time'] + model.subloss_names
            log_writer.writerow(row)

    step_scheduler, epoch_scheduler = build_scheduler(cfg, databunch, opt,
                                                      start_epoch)

    for epoch in range(start_epoch, cfg.solver.num_epochs):
        start = time.time()
        train_loss = train_epoch(cfg, model, device, databunch.train_dl, opt,
                                 step_scheduler, epoch_scheduler)
        end = time.time()
        epoch_time = datetime.timedelta(seconds=end - start)
        # NOTE(review): epoch_scheduler is also handed to train_epoch above;
        # confirm it is not stepped a second time in there.
        if epoch_scheduler:
            epoch_scheduler.step()

        print('----------------------------------------')
        print('epoch: {}'.format(epoch), flush=True)
        print('train loss: {}'.format(train_loss), flush=True)
        print('elapsed: {}'.format(epoch_time), flush=True)

        metrics = validate_epoch(cfg, model, device, databunch.valid_dl,
                                 num_labels)
        print('validation metrics: {}'.format(metrics), flush=True)

        # NOTE(review): selecting the best model on metrics['map50'] is
        # currently disabled. best_model.pth is overwritten every epoch and
        # best_metric keeps its initial (or restored) value -- confirm this
        # is still wanted before re-enabling the comparison.
        torch.save(model.state_dict(), best_model_path)
        torch.save(model.state_dict(), last_model_path)

        train_state = {'epoch': epoch, 'best_metric': best_metric}
        json_to_file(train_state, train_state_path)

        with open(log_path, 'a', newline='') as log_file:
            log_writer = csv.writer(log_file)
            row = [epoch]
            row += [metrics['map50'], epoch_time]
            row += [train_loss[k] for k in model.subloss_names]
            log_writer.writerow(row)

        if (cfg.output_uri.startswith('s3://')
                and ((epoch + 1) % cfg.solver.sync_interval == 0)):
            sync_to_dir(output_dir, cfg.output_uri)
Пример #5
0
def train(config_path, opts):
    """Train and/or evaluate a model described by a config file.

    Loads the config, builds the databunch, model and optimizer, runs the
    overfit or regular training loop unless ``cfg.predict_mode`` is set,
    then evaluates on the test set (and on the train set when
    ``cfg.eval_train`` is set), writes metrics JSON files and prediction
    plots to the output directory, and finally syncs that directory to
    ``cfg.output_uri`` when it starts with ``s3://``.

    Args:
        config_path: path of the config file; it is also copied to
            ``config.yml`` in the output directory.
        opts: forwarded to ``load_config`` together with ``config_path``.
    """
    torch_cache_dir = '/opt/data/torch-cache'
    os.environ['TORCH_HOME'] = torch_cache_dir

    # The temporary directory is used as a context manager so it is removed
    # deterministically when the run finishes or fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        cfg = load_config(config_path, opts)
        print(cfg)
        print()

        # Setup data
        databunch = build_databunch(cfg, tmp_dir)
        output_dir = setup_output_dir(cfg, tmp_dir)
        shutil.copyfile(config_path, join(output_dir, 'config.yml'))
        print(databunch)
        print()

        plotter = build_plotter(cfg)
        if not cfg.predict_mode:
            plotter.plot_dataloaders(databunch, output_dir)

        # Setup model
        num_labels = len(databunch.label_names)
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        model = build_model(cfg, num_labels)
        model.to(device)
        opt = build_optimizer(cfg, model)

        # TODO tensorboard, progress bar
        if cfg.model.init_weights:
            model.load_state_dict(
                torch.load(cfg.model.init_weights, map_location=device))

        if not cfg.predict_mode:
            if cfg.overfit_mode:
                overfit_loop(cfg, databunch, model, opt, device, output_dir)
            else:
                train_loop(cfg, databunch, model, opt, device, output_dir)

        if cfg.eval_train:
            print('\nEvaluating on train set...')
            metrics = validate_epoch(cfg, model, device, databunch.train_dl,
                                     num_labels)
            print('train metrics: {}'.format(metrics))
            json_to_file(metrics, join(output_dir, 'train_metrics.json'))

            print('\nPlotting training set predictions...')
            plotter.make_debug_plots(databunch.train_dl, model,
                                     databunch.label_names,
                                     join(output_dir, 'train_preds.zip'), cfg)

        print('\nEvaluating on test set...')
        metrics = validate_epoch(cfg, model, device, databunch.test_dl,
                                 num_labels)
        print('test metrics: {}'.format(metrics))
        json_to_file(metrics, join(output_dir, 'test_metrics.json'))

        print('\nPlotting test set predictions...')
        plotter.make_debug_plots(databunch.test_dl, model,
                                 databunch.label_names,
                                 join(output_dir, 'test_preds.zip'), cfg)

        # Sync while the temporary directory still exists, since the output
        # directory is set up relative to it by setup_output_dir.
        if cfg.output_uri.startswith('s3://'):
            sync_to_dir(output_dir, cfg.output_uri)
Пример #6
0
def main(test, s3_data, batch, debug):
    """Train an object detection model on Pascal VOC 2007 and plot predictions.

    Args:
        test: if truthy, run a reduced configuration (batch size 8, image
            size 128, only the first 32 items, no loader workers, one epoch).
        s3_data: if truthy, sync ``output_dir`` to ``output_uri`` at the end.
        batch: if truthy, ``run_on_batch(test, debug)`` is called first.
        debug: only forwarded to ``run_on_batch``.
    """
    if batch:
        # NOTE(review): execution continues with the local run afterwards.
        run_on_batch(test, debug)

    # Options. The grid size goes with the image size (8 for 256, 4 for 128);
    # 2 is subtracted because there's no padding on the final convolution.
    lr = 1e-4
    if test:
        bs, size = 8, 128
        num_workers, num_epochs = 0, 1
        num_debug_images = 32
        grid_sz = 4 - 2
    else:
        bs, size = 16, 256
        num_workers, num_epochs = 4, 100
        grid_sz = 8 - 2

    # Data locations and annotations
    make_dir(output_dir)

    data_dir = untar_data(URLs.PASCAL_2007, dest='/opt/data/pascal2007/data')
    img_path = data_dir/'train/'
    trn_path = data_dir/'train.json'
    trn_images, trn_lbl_bbox = get_annotations(trn_path)
    val_path = data_dir/'valid.json'
    val_images, val_lbl_bbox = get_annotations(val_path)

    # Map every image file name (train and valid) to its boxes and labels.
    img2bbox = dict(zip(trn_images + val_images, trn_lbl_bbox + val_lbl_bbox))

    def get_y_func(item):
        return img2bbox[item.name]

    # Class names ordered by category id, with 'background' first.
    with open(trn_path) as f:
        categories = json.load(f)['categories']
    classes = ['background']
    classes += [cat['name']
                for cat in sorted(categories, key=lambda c: c['id'])]
    num_classes = len(classes)

    anc_sizes = torch.tensor(
        [[1, 1], [2, 2], [3, 3], [3, 1], [1, 3]], dtype=torch.float32)
    grid = ObjectDetectionGrid(grid_sz, anc_sizes, num_classes)
    score_thresh = 0.1
    iou_thresh = 0.8

    class MyObjectCategoryList(ObjectCategoryList):
        def analyze_pred(self, pred):
            # The grid decodes batches, so wrap the single prediction in a
            # batch of one and unwrap the first result again.
            boxes, labels, _unused = grid.get_preds(
                pred.unsqueeze(0),
                score_thresh=score_thresh,
                iou_thresh=iou_thresh)
            return (boxes[0], labels[0])

    class MyObjectItemList(ObjectItemList):
        _label_cls = MyObjectCategoryList

    def get_data(bs, size):
        items = MyObjectItemList.from_folder(img_path)
        if test:
            items = items[0:num_debug_images]
        labelled = (items.split_by_files(val_images)
                    .label_from_func(get_y_func, classes=classes)
                    .transform(get_transforms(), size=size, tfm_y=True))
        return labelled.databunch(path=data_dir,
                                  bs=bs,
                                  collate_fn=bb_pad_collate,
                                  num_workers=num_workers)

    data = get_data(bs, size)
    print(data)
    plot_data(data, output_dir)

    # Model, loss and training
    model = ObjectDetectionModel(grid)

    def loss(out, gt_boxes, gt_classes):
        encoded_gt = model.grid.encode(gt_boxes, gt_classes)
        box_loss, class_loss = model.grid.compute_losses(out, encoded_gt)
        return box_loss + class_loss

    f1_metric = F1(grid, score_thresh=score_thresh, iou_thresh=iou_thresh)
    learn = Learner(data,
                    model,
                    metrics=[f1_metric],
                    loss_func=loss,
                    path=output_dir)
    # model.freeze_body()
    learn.fit_one_cycle(num_epochs, lr,
                        callbacks=[CSVLogger(learn, filename='log')])

    plot_preds(data, learn, output_dir)

    if s3_data:
        sync_to_dir(output_dir, output_uri)
Пример #7
0
def train(config_path, opts):
    """Train, overfit, run the LR finder or just predict, depending on config.

    The mode is selected by flags on the loaded config:

    * ``predict_mode``: skip training and load ``best_model.pth`` from the
      output directory.
    * ``overfit_mode``: fit, save the weights as ``best_model.pth`` and
      validate on the full training data loader.
    * ``lr_find_mode``: run the learning-rate finder, print the result and
      exit the process.
    * otherwise: fit with best/last checkpointing and validate on the full
      validation data loader.

    Afterwards predictions are plotted and, when ``cfg.output_uri`` starts
    with ``s3://``, the output directory is synced there.

    Args:
        config_path: forwarded to ``load_config``.
        opts: forwarded to ``load_config``.
    """
    # NOTE(review): the temporary directory is never cleaned up explicitly;
    # removal is left to the TemporaryDirectory object's finalizer.
    tmp_dir_obj = tempfile.TemporaryDirectory()
    tmp_dir = tmp_dir_obj.name

    cfg = load_config(config_path, opts)
    print(cfg)

    # Setup data
    # Two databunches are returned: `databunch` is used for training and
    # plotting, `full_databunch` for the final validation passes below.
    databunch, full_databunch = build_databunch(cfg, tmp_dir)
    output_dir = setup_output_dir(cfg, tmp_dir)
    print(full_databunch)

    plotter = build_plotter(cfg)
    if not cfg.lr_find_mode and not cfg.predict_mode:
        plotter.plot_data(databunch, output_dir)

    # Setup model
    num_labels = databunch.c
    model = build_model(cfg, num_labels)
    metrics = [CocoMetric(num_labels)]
    learn = Learner(databunch, model, path=output_dir, metrics=metrics)
    # Replace fastai's module-level loss_batch with the project's own
    # implementation; this affects every later use of
    # fastai.basic_train.loss_batch in the process.
    fastai.basic_train.loss_batch = loss_batch
    best_model_path = join(output_dir, 'best_model.pth')
    last_model_path = join(output_dir, 'last_model.pth')

    # Train model
    callbacks = [
        MyCSVLogger(learn, filename='log'),
        SubLossMetric(learn, model.subloss_names)
    ]

    # Periodically push the output directory to S3 while training.
    if cfg.output_uri.startswith('s3://'):
        callbacks.append(
            SyncCallback(output_dir, cfg.output_uri, cfg.solver.sync_interval))

    if cfg.model.init_weights:
        # Load the initial weights onto whatever device the model is on.
        device = next(model.parameters()).device
        model.load_state_dict(
            torch.load(cfg.model.init_weights, map_location=device))

    if not cfg.predict_mode:
        if cfg.overfit_mode:
            learn.fit_one_cycle(cfg.solver.num_epochs, cfg.solver.lr, callbacks=callbacks)
            # No save-model callbacks are used in this branch, so the final
            # weights are written out explicitly.
            torch.save(learn.model.state_dict(), best_model_path)
            learn.model.eval()
            print('Validating on training set...')
            learn.validate(full_databunch.train_dl, metrics=metrics)
        else:
            # NOTE(review): tb_logger is configured but never appended to
            # `callbacks`; presumably the constructor registers it on
            # `learn` -- confirm. cfg.overfit_mode is always falsy in this
            # branch.
            tb_logger = TensorboardLogger(learn, 'run')
            tb_logger.set_extra_args(
                model.subloss_names, cfg.overfit_mode)

            extra_callbacks = [
                MySaveModelCallback(
                    learn, best_model_path, monitor='coco_metric', every='improvement'),
                MySaveModelCallback(learn, last_model_path, every='epoch'),
                TrackEpochCallback(learn),
            ]
            callbacks.extend(extra_callbacks)
            if cfg.lr_find_mode:
                learn.lr_find()
                learn.recorder.plot(suggestion=True, return_fig=True)
                lr = learn.recorder.min_grad_lr
                print('lr_find() found lr: {}'.format(lr))
                # LR-finder runs stop here: nothing below (training,
                # plotting, S3 sync) is executed.
                exit()

            learn.fit_one_cycle(cfg.solver.num_epochs, cfg.solver.lr, callbacks=callbacks)
            print('Validating on full validation set...')
            learn.validate(full_databunch.valid_dl, metrics=metrics)
    else:
        # Predict mode: reuse the best checkpoint from the output directory.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model.load_state_dict(
            torch.load(join(output_dir, 'best_model.pth'), map_location=device))
        model.eval()
        # NOTE(review): this assignment has no effect; plot_dataset is
        # unconditionally reassigned right after the print below.
        plot_dataset = databunch.train_ds

    print('Plotting predictions...')
    plot_dataset = databunch.train_ds if cfg.overfit_mode else databunch.valid_ds
    plotter.make_debug_plots(plot_dataset, model, databunch.classes, output_dir)
    if cfg.output_uri.startswith('s3://'):
        sync_to_dir(output_dir, cfg.output_uri)
Пример #8
0
def main(config_path, opts):
    """Train and/or evaluate a model described by a config file.

    Sets up the output directory (pulling existing contents from S3 first
    unless ``cfg.overfit_mode`` is set), builds the databunch, model,
    optimizer, schedulers and ``Learner``, trains unless
    ``cfg.predict_mode`` is set, then evaluates and plots predictions on the
    test set (and on the train set when ``cfg.eval_train`` is set). The
    output directory is finally synced to ``cfg.output_uri`` when it starts
    with ``s3://``.

    Args:
        config_path: path of the config file; it is also copied to
            ``config.yml`` in the output directory.
        opts: forwarded to ``load_config`` together with ``config_path``.
    """
    # Load config and setup output_dir.
    torch_cache_dir = '/opt/data/torch-cache'
    os.environ['TORCH_HOME'] = torch_cache_dir

    # The temporary directory is used as a context manager so it is removed
    # deterministically when the run finishes or fails.
    with tempfile.TemporaryDirectory() as tmp_dir:
        cfg = load_config(config_path, opts)
        if cfg.output_uri.startswith('s3://'):
            output_dir = get_local_path(cfg.output_uri, tmp_dir)
            make_dir(output_dir, force_empty=True)
            # Pull earlier results so that training can resume from them;
            # overfit runs always start from an empty directory.
            if not cfg.overfit_mode:
                sync_from_dir(cfg.output_uri, output_dir)
        else:
            output_dir = cfg.output_uri
            make_dir(output_dir)
        shutil.copyfile(config_path, join(output_dir, 'config.yml'))

        print(cfg)
        print()

        # Setup databunch and plot.
        databunch = build_databunch(cfg)
        print(databunch)
        print()
        if not cfg.predict_mode:
            databunch.plot_dataloaders(output_dir)

        # Setup learner.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        model = build_model(cfg)
        model.to(device)
        if cfg.model.init_weights:
            model.load_state_dict(
                torch.load(cfg.model.init_weights, map_location=device))

        opt = build_optimizer(cfg, model)
        loss_fn = torch.nn.CrossEntropyLoss()
        # The schedulers need to know where a resumed run picks up.
        start_epoch = 0
        train_state_path = join(output_dir, 'train_state.json')
        if isfile(train_state_path):
            train_state = file_to_json(train_state_path)
            start_epoch = train_state['epoch'] + 1
        num_samples = len(databunch.train_ds)
        step_scheduler, epoch_scheduler = build_scheduler(
            cfg, num_samples, opt, start_epoch)
        learner = Learner(cfg, databunch, output_dir, model, loss_fn, opt,
                          device, epoch_scheduler, step_scheduler)

        # Train
        if not cfg.predict_mode:
            if cfg.overfit_mode:
                learner.overfit()
            else:
                learner.train()

        # Evaluate on test set and plot.
        if cfg.eval_train:
            print('\nEvaluating on train set...')
            # NOTE(review): this calls learner.validate_epoch while the test
            # set below uses learner.validate -- confirm both methods exist
            # and that the difference is intended.
            metrics = learner.validate_epoch(databunch.train_dl)
            print('train metrics: {}'.format(metrics))
            json_to_file(metrics, join(output_dir, 'train_metrics.json'))

            print('\nPlotting training set predictions...')
            learner.plot_preds(databunch.train_dl,
                               join(output_dir, 'train_preds.png'))

        print('\nEvaluating on test set...')
        metrics = learner.validate(databunch.test_dl)
        print('test metrics: {}'.format(metrics))
        json_to_file(metrics, join(output_dir, 'test_metrics.json'))

        print('\nPlotting test set predictions...')
        learner.plot_preds(databunch.test_dl,
                           join(output_dir, 'test_preds.png'))

        # Sync while the temporary directory (which holds output_dir for S3
        # runs) still exists.
        if cfg.output_uri.startswith('s3://'):
            sync_to_dir(output_dir, cfg.output_uri)