Code example #1
File: data.py  Project: lewfish/mlx
def get_label_names(coco_path):
    # Map COCO category ids to names, assuming the ids are contiguous
    # from 1 to N, and reserve index 0 for the 'background' class.
    categories = file_to_json(coco_path)['categories']
    label2name = {cat['id']: cat['name'] for cat in categories}
    labels = ['background'] + [label2name[i]
                               for i in range(1, len(label2name) + 1)]
    return labels
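
A quick way to sanity-check this helper is to feed it a tiny COCO-style file. The stand-in file_to_json below is an assumption about the project helper (presumably it just parses JSON from a path), and the category names are made up for illustration:

import json
import tempfile

def file_to_json(path):
    # Stand-in for the project's helper; assumed to parse JSON from a path.
    with open(path) as f:
        return json.load(f)

coco = {'categories': [{'id': 1, 'name': 'cat'}, {'id': 2, 'name': 'dog'}]}
with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
    json.dump(coco, f)

print(get_label_names(f.name))  # ['background', 'cat', 'dog']

Note that the function assumes category ids run contiguously from 1 to N; a gap in the ids would raise a KeyError.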
Code example #2
File: data.py  Project: lewfish/mlx
    def __init__(self, img_dir, annotation_uris, transforms=None):
        self.img_dir = img_dir
        self.annotation_uris = annotation_uris
        self.transforms = transforms

        self.imgs = []
        self.img2id = {}
        self.id2img = {}
        # Accumulate boxes and labels per image id across annotation files.
        self.id2boxes = defaultdict(lambda: [])
        self.id2labels = defaultdict(lambda: [])
        self.label2name = {}
        for annotation_uri in annotation_uris:
            ann_json = file_to_json(annotation_uri)
            for img in ann_json['images']:
                self.imgs.append(img['file_name'])
                self.img2id[img['file_name']] = img['id']
                self.id2img[img['id']] = img['file_name']
            for ann in ann_json['annotations']:
                img_id = ann['image_id']
                box = ann['bbox']
                label = ann['category_id']
                self.id2boxes[img_id].append(box)
                self.id2labels[img_id].append(label)

        # Fix the image order across runs, then freeze the accumulators
        # into plain dicts (a lambda-based defaultdict cannot be pickled;
        # see the note below).
        random.seed(1234)
        random.shuffle(self.imgs)
        self.id2boxes = dict(self.id2boxes)
        self.id2labels = dict(self.id2labels)
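
The closing conversion to plain dicts matters for multiprocessing: a defaultdict whose default_factory is a lambda cannot be pickled, and a PyTorch DataLoader pickles its dataset when it spawns worker processes on platforms that use the spawn start method. A minimal demonstration:

import pickle
from collections import defaultdict

d = defaultdict(lambda: [])
d['a'].append(1)

try:
    pickle.dumps(d)
except (pickle.PicklingError, AttributeError):
    print('a lambda default_factory is not picklable')

pickle.dumps(dict(d))  # the plain-dict copy round-trips fine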
Code example #3
File: learner.py  Project: lewfish/mlx
    def train(self):
        last_model_path = join(self.output_dir, 'last_model.pth')
        start_epoch = 0

        log_path = join(self.output_dir, 'log.csv')
        train_state_path = join(self.output_dir, 'train_state.json')
        if isfile(train_state_path):
            # Resume: restore the epoch counter and the last saved weights.
            print('Resuming from checkpoint: {}\n'.format(last_model_path))
            train_state = file_to_json(train_state_path)
            start_epoch = train_state['epoch'] + 1
            self.model.load_state_dict(
                torch.load(last_model_path, map_location=self.device))

        metric_names = ['precision', 'recall', 'f1', 'accuracy']
        if not isfile(log_path):
            with open(log_path, 'w') as log_file:
                log_writer = csv.writer(log_file)
                row = ['epoch', 'time', 'train_loss'] + metric_names
                log_writer.writerow(row)

        for epoch in range(start_epoch, self.cfg.solver.num_epochs):
            start = time.time()
            train_loss = self.train_epoch(self.databunch.train_dl)
            end = time.time()
            epoch_time = datetime.timedelta(seconds=end - start)
            if self.epoch_scheduler:
                self.epoch_scheduler.step()

            print('----------------------------------------')
            print('epoch: {}'.format(epoch), flush=True)
            print('train loss: {}'.format(train_loss), flush=True)
            print('elapsed: {}'.format(epoch_time), flush=True)

            metrics = self.validate(self.databunch.valid_dl)
            print('validation metrics: {}'.format(metrics), flush=True)

            torch.save(self.model.state_dict(), last_model_path)
            train_state = {'epoch': epoch}
            json_to_file(train_state, train_state_path)

            with open(log_path, 'a') as log_file:
                log_writer = csv.writer(log_file)
                row = [epoch, epoch_time, train_loss]
                row += [metrics[k] for k in metric_names]
                log_writer.writerow(row)

            if (self.cfg.output_uri.startswith('s3://')
                    and ((epoch + 1) % self.cfg.solver.sync_interval == 0)):
                sync_to_dir(self.output_dir, self.cfg.output_uri)
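
Resumption hinges on two small JSON helpers that the listing does not show. A plausible sketch, assuming they are thin wrappers over the json module (the real mlx versions may differ, e.g. by accepting URIs):

import json

def json_to_file(obj, path):
    # Assumed behavior: write a dict such as {'epoch': 4} to disk.
    with open(path, 'w') as f:
        json.dump(obj, f)

def file_to_json(path):
    with open(path) as f:
        return json.load(f)

With these in place, a run killed after epoch 4 leaves {"epoch": 4} in train_state.json, so the next invocation sets start_epoch to 5 and reloads last_model.pth. Only the model weights are restored; the optimizer state starts fresh.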
Code example #4
def train_loop(cfg, databunch, model, opt, device, output_dir):
    best_model_path = join(output_dir, 'best_model.pth')
    last_model_path = join(output_dir, 'last_model.pth')
    num_labels = len(databunch.label_names)

    best_metric = -1.0
    start_epoch = 0
    train_state_path = join(output_dir, 'train_state.json')
    log_path = join(output_dir, 'log.csv')
    if isfile(train_state_path):
        print('Resuming from checkpoint: {}\n'.format(last_model_path))
        train_state = file_to_json(train_state_path)
        start_epoch = train_state['epoch'] + 1
        best_metric = train_state['best_metric']
        model.load_state_dict(torch.load(last_model_path, map_location=device))

    if not isfile(log_path):
        with open(log_path, 'w') as log_file:
            log_writer = csv.writer(log_file)
            row = ['epoch', 'map50', 'time'] + model.subloss_names
            log_writer.writerow(row)

    step_scheduler, epoch_scheduler = build_scheduler(cfg, databunch, opt,
                                                      start_epoch)

    for epoch in range(start_epoch, cfg.solver.num_epochs):
        start = time.time()
        train_loss = train_epoch(cfg, model, device, databunch.train_dl, opt,
                                 step_scheduler, epoch_scheduler)
        end = time.time()
        epoch_time = datetime.timedelta(seconds=end - start)
        if epoch_scheduler:
            epoch_scheduler.step()

        print('----------------------------------------')
        print('epoch: {}'.format(epoch), flush=True)
        print('train loss: {}'.format(train_loss), flush=True)
        print('elapsed: {}'.format(epoch_time), flush=True)

        metrics = validate_epoch(cfg, model, device, databunch.valid_dl,
                                 num_labels)
        print('validation metrics: {}'.format(metrics), flush=True)
        # Best-model selection on map50 is disabled for now, so the "best"
        # checkpoint is simply overwritten every epoch alongside the last
        # one (and best_metric in train_state therefore never updates).
        # if metrics['map50'] > best_metric:
        #     best_metric = metrics['map50']
        #     torch.save(model.state_dict(), best_model_path)
        torch.save(model.state_dict(), best_model_path)
        torch.save(model.state_dict(), last_model_path)

        train_state = {'epoch': epoch, 'best_metric': best_metric}
        json_to_file(train_state, train_state_path)

        with open(log_path, 'a') as log_file:
            log_writer = csv.writer(log_file)
            row = [epoch]
            row += [metrics['map50'], epoch_time]
            row += [train_loss[k] for k in model.subloss_names]
            log_writer.writerow(row)

        if (cfg.output_uri.startswith('s3://')
                and ((epoch + 1) % cfg.solver.sync_interval == 0)):
            sync_to_dir(output_dir, cfg.output_uri)
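
sync_to_dir is likewise not shown. One hedged sketch of the assumed behavior, shelling out to the AWS CLI (the project may well implement this differently):

import subprocess

def sync_to_dir(src_dir, dest_uri):
    # Assumed behavior: mirror a local directory to an S3 prefix so
    # checkpoints and logs survive a preempted or crashed instance.
    subprocess.run(['aws', 's3', 'sync', src_dir, dest_uri], check=True)

Syncing every sync_interval epochs rather than every epoch bounds the upload cost while still limiting how much work a crash can lose.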
Code example #5
File: main.py  Project: lewfish/mlx
def main(config_path, opts):
    # Load config and set up the output dir.
    torch_cache_dir = '/opt/data/torch-cache'
    os.environ['TORCH_HOME'] = torch_cache_dir
    tmp_dir_obj = tempfile.TemporaryDirectory()
    tmp_dir = tmp_dir_obj.name

    cfg = load_config(config_path, opts)
    if cfg.output_uri.startswith('s3://'):
        output_dir = get_local_path(cfg.output_uri, tmp_dir)
        make_dir(output_dir, force_empty=True)
        if not cfg.overfit_mode:
            sync_from_dir(cfg.output_uri, output_dir)
    else:
        output_dir = cfg.output_uri
        make_dir(output_dir)
    shutil.copyfile(config_path, join(output_dir, 'config.yml'))

    print(cfg)
    print()

    # Set up the databunch and plot sample batches.
    databunch = build_databunch(cfg)
    print(databunch)
    print()
    if not cfg.predict_mode:
        databunch.plot_dataloaders(output_dir)

    # Set up the learner.
    num_labels = len(databunch.label_names)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = build_model(cfg)
    model.to(device)
    if cfg.model.init_weights:
        model.load_state_dict(
            torch.load(cfg.model.init_weights, map_location=device))

    opt = build_optimizer(cfg, model)
    loss_fn = torch.nn.CrossEntropyLoss()
    start_epoch = 0
    train_state_path = join(output_dir, 'train_state.json')
    if isfile(train_state_path):
        train_state = file_to_json(train_state_path)
        start_epoch = train_state['epoch'] + 1
    num_samples = len(databunch.train_ds)
    step_scheduler, epoch_scheduler = build_scheduler(cfg, num_samples, opt,
                                                      start_epoch)
    learner = Learner(cfg, databunch, output_dir, model, loss_fn, opt, device,
                      epoch_scheduler, step_scheduler)

    # Train
    if not cfg.predict_mode:
        if cfg.overfit_mode:
            learner.overfit()
        else:
            learner.train()

    # Optionally evaluate on the train set, then on the test set, and plot.
    if cfg.eval_train:
        print('\nEvaluating on train set...')
        metrics = learner.validate_epoch(databunch.train_dl)
        print('train metrics: {}'.format(metrics))
        json_to_file(metrics, join(output_dir, 'train_metrics.json'))

        print('\nPlotting training set predictions...')
        learner.plot_preds(databunch.train_dl,
                           join(output_dir, 'train_preds.png'))

    print('\nEvaluating on test set...')
    metrics = learner.validate(databunch.test_dl)
    print('test metrics: {}'.format(metrics))
    json_to_file(metrics, join(output_dir, 'test_metrics.json'))

    print('\nPlotting test set predictions...')
    learner.plot_preds(databunch.test_dl, join(output_dir, 'test_preds.png'))

    if cfg.output_uri.startswith('s3://'):
        sync_to_dir(output_dir, cfg.output_uri)
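
For reference, the config fields this entrypoint reads can be collected from the code above. The sketch below records only those inferred names; the actual schema is defined by the project's load_config, and every value here is a placeholder:

from types import SimpleNamespace

# Inferred shape only; 's3://my-bucket/runs/exp1' is a made-up URI.
cfg = SimpleNamespace(
    output_uri='s3://my-bucket/runs/exp1',  # or a local directory
    overfit_mode=False,
    predict_mode=False,
    eval_train=False,
    model=SimpleNamespace(init_weights=None),
    solver=SimpleNamespace(num_epochs=10, sync_interval=2),
)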