Example #1
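
# Shared setup assumed by all four examples below (a sketch inferred from
# usage, not verbatim from the source repo). Project-local names -- models,
# vis, PartitionProvider, InfiniteProviderFromPartitions, CustomDataset,
# get_seg_model, get_frames, logit_to_img, convert_batch_to_frames, evaluate,
# evaluate_segtransfer, evaluate_transfer, log_and_viz_results, EarlyStopper --
# come from the repo itself and are not reproduced here.
import os
from itertools import count
from timeit import default_timer as timer

import cv2
import imageio
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
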
def main(input_file, output_file, segmentation_model,
         seg_model_name, transform_model_name, frame_skip, batch_size,
         max_frames):
    """
    used to create a gif with lines segmented from a video files
    """
    assert output_file.endswith('.gif'), 'Make sure output_file is a .gif'

    print('Loading models..')
    num_classes = 4
    input_channels = 3
    model_seg = get_seg_model(seg_model_name, num_classes, input_channels).to(device)
    model_seg.load(segmentation_model)

    print('Loading data..')
    cap = cv2.VideoCapture(input_file)
    images = torch.tensor(get_frames(cap, frame_skip, max_frames)).to(torch.uint8)
    num_images = images.shape[0]
    data_iterator = DataLoader(
        dataset=CustomDataset(images),
        batch_size=batch_size,
        shuffle=False,
        num_workers=2,
        drop_last=False
    )
    print('\tNumber of frames to convert:\t{} (frame skip: {})'.format(num_images, frame_skip))

    print('Converting..')
    model_seg.eval()
    batch_count = 0
    gif_frames = []
    with torch.no_grad():
        start = timer()
        for batch in data_iterator:
            batch_count += 1

            # get segmentation
            seg_logits = model_seg(batch.to(device))
            seg_preds = torch.argmax(seg_logits, dim=1).cpu()

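            # rescale to 8-bit RGB for GIF encoding (the dataset is assumed
            # to yield floats in [0, 1])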
            source = torch.mul(batch.cpu(), 255).to(torch.uint8)
            segmented = logit_to_img(seg_preds.numpy()).transpose(0, 3, 1, 2)
            segmented = torch.mul(torch.tensor(segmented), 255).to(torch.uint8)

            # convert torch predictions to frames of grid
            gif_frames.extend(convert_batch_to_frames(source, segmented))

            if batch_count % 50 == 0:
                print('\tframe {} / {} - {:.2f} secs'.format(
                    batch_count*batch_size, num_images, timer() - start)
                )
                start = timer()

        del batch, source, segmented
        # convert sequence of frames into gif
        print('Saving {}..'.format(output_file))
        imageio.mimsave(output_file, gif_frames, fps=29.97/frame_skip, subrectangles=True)
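
# Hypothetical invocation (paths, model name and settings are placeholders):
# main('input.mp4', 'output.gif', 'checkpoints/segnet.pth',
#      seg_model_name='segnet', transform_model_name=None,
#      frame_skip=2, batch_size=8, max_frames=300)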

Example #2
def main(data_sim_dir, data_real_dir, data_label_dir, save_dir, visdom_dir,
         batch_size, config_file, seg_model_name, early_stop_patience, server,
         port, reload, run_name):
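    """
    Train a segmentation model on simulated, labelled data served in
    partitions; evaluate after every partition, with early stopping on
    validation accuracy and visdom logging.
    """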

    print('Loading data..')
    num_classes = 4
    input_channels = 3
    real_data = PartitionProvider(input_dir=data_real_dir,
                                  label_dir=None,
                                  num_workers=0,
                                  partition_batch_size=batch_size,
                                  partition_num_workers=2)

    sim_data = PartitionProvider(input_dir=data_sim_dir,
                                 label_dir=data_label_dir,
                                 num_workers=0,
                                 partition_batch_size=batch_size,
                                 partition_num_workers=2)

    print('Building model & loading on GPU (if applicable)..')
    seg_model = models.get_seg_model(seg_model_name, num_classes,
                                     input_channels).to(device)
    if save_dir:
        seg_model.save(os.path.join(save_dir, '{}.pth'.format(seg_model.name)))

    # adjusted class weights [black, white, red, yellow] see README.md
    class_weights = torch.tensor([0.0051, 0.0551, 0.6538, 0.2860]).to(device)

    loss = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(seg_model.parameters())

    print('Initializing misc..')
    batch_count = 0
    partition_count = 0
    results = dict()
    early_stopper = EarlyStopper('accuracy', early_stop_patience)
    visualiser = vis.Visualiser(server, port, run_name, reload, visdom_dir)

    print('Starting training..')
    for epoch_id in count(start=1):
        for sim_partition in sim_data.train_partition_iterator:
            start = timer()
            seg_model.train()
            partition_loss = 0
            batch_per_part = 0
            partition_count += 1

            sim_data_train_iterator = sim_data.get_train_iterator(
                sim_partition)

            for batch_id, batch in enumerate(sim_data_train_iterator):
                batch_count += 1
                batch_per_part += 1

                inputs, labels = batch
                logits = seg_model(inputs.to(device))

                optimizer.zero_grad()
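                # per-pixel cross-entropy: flatten (N, C, H, W) logits to
                # (N*H*W, C) and labels to (N*H*W,)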
                loss_seg = loss(
                    logits.permute(0, 2, 3, 1).contiguous().view(-1, num_classes),
                    labels.view(-1).to(device))
                loss_seg.backward()
                optimizer.step()

                partition_loss += loss_seg.item()
                X, data = np.array([batch_count]), np.array([loss_seg.item()])
                visualiser.plot(X,
                                data,
                                title='Loss per batch',
                                legend=['Loss'],
                                iteration=2,
                                update='append')

            del logits
            del loss_seg
            optimizer.zero_grad()
            torch.cuda.empty_cache()

            results['partition_avg_loss'] = np.divide(partition_loss,
                                                      batch_per_part)
            results.update(evaluate(seg_model, sim_data, device))
            early_stopper.update(results, epoch_id, batch_count)
            log_and_viz_results(results, epoch_id, batch_count,
                                partition_count, visualiser, start)

            if early_stopper.new_best and save_dir:
                seg_model.save(
                    os.path.join(save_dir, '{}.pth'.format(seg_model.name)))

            if early_stopper.stop:
                early_stopper.print_stop()
                return

Example #3
def main(data_sim_dir, data_real_dir, data_label_dir, save_dir, visdom_dir,
         batch_size, config_file, discr_model_name, gen_model_name,
         early_stop_patience, max_num_batch, server, port, reload, run_name,
         batch_per_eval, batch_per_save, seg_model_path, seg_model_name,
         content_weight):
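    """
    Domain-adapt a pretrained segmentation network: a discriminator learns to
    tell simulated from real bottleneck embeddings, while the segmentation
    network is trained both to fool it and to minimise the supervised
    segmentation loss on simulated batches.
    """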

    print('Loading data..')
    num_classes = 4
    input_channels = 3
    real_data = InfiniteProviderFromPartitions(
        input_dir=data_real_dir,
        label_dir=None,
        num_workers=0,
        partition_batch_size=batch_size,
        partition_num_workers=2
    )

    sim_data = InfiniteProviderFromPartitions(
        input_dir=data_sim_dir,
        label_dir=data_label_dir,
        num_workers=0,
        partition_batch_size=batch_size,
        partition_num_workers=2
    )

    real_data.init_train_iterator()
    sim_data.init_train_iterator()

    print('Building model & loading on GPU (if applicable)..')
    # model_seg is used unconditionally below (its bottleneck feeds the
    # discriminator), so a pretrained checkpoint is required
    assert seg_model_path, 'seg_model_path is required'
    assert seg_model_name in seg_model_path
    model_seg = models.get_seg_model(seg_model_name, num_classes,
                                     input_channels).to(device)
    model_seg.load(seg_model_path)
    model_seg.name = 'segnet_transfer'

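    # the discriminator operates on the segmentation network's bottleneck
    # features rather than on images (feature-level adversarial adaptation)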
    model_discr = models.get_discriminator_model(discr_model_name,
                                                 model_seg.size_bottleneck,
                                                 stride=1,
                                                 flat_size=128 * 5 * 3).to(device)
    if save_dir:
        model_discr.save(os.path.join(save_dir, '{}_{}.pth'.format(model_discr.name, 0)))

    obj_adv = nn.BCELoss()
    # float labels: BCELoss expects float targets
    label_true = 1.
    label_fake = 0.

    # the segmentation network acts as the generator; optimise all its parameters
    optim_gen = optim.Adam(model_seg.parameters())
    optim_discr = optim.Adam(model_discr.parameters())

    print('Initializing misc..')
    batch_count = 0
    results = dict()
    results['accuracy'] = 0
    visualiser = vis.Visualiser(server, port, run_name, reload, visdom_dir)
    class_weights = torch.tensor([0.0051, 0.0551, 0.6538, 0.2860]).to(device)
    loss_seg = nn.CrossEntropyLoss(weight=class_weights)

    print('Starting training..')
    for eval_count in count(start=1):

        batch_since_eval = 0
        loss_discr_true_sum = 0
        loss_discr_fake_sum = 0
        loss_gen_sum = 0
        model_seg.train()
        model_discr.train()
        start = timer()

        while batch_since_eval < batch_per_eval:

            batch_real = next(real_data).to(device)
            batch_sim, label_sim = next(sim_data)
            batch_sim, label_sim = batch_sim.to(device), label_sim.to(device)
            assert batch_sim.shape[0] == batch_real.shape[0]
            b_size = batch_sim.shape[0]
            batch_count += 1

            optim_discr.zero_grad()
            # train discriminator on true (simulated) embeddings: log D(x)
            emb_sim = model_seg(batch_sim, bottleneck=True)
            scores_true = model_discr(emb_sim.detach())
            labels = torch.full((b_size, ), label_true, device=device)
            loss_d_true = obj_adv(scores_true.view(-1), labels)
            loss_d_true.backward()

            # train discriminator on fake data: log(1 - D(G(x)))
            batch_fake = model_seg(batch_real, bottleneck=True)
            scores_fake = model_discr(batch_fake.detach())
            labels.fill_(label_fake)
            loss_d_fake = obj_adv(scores_fake.view(-1), labels)
            loss_d_fake.backward()

            optim_discr.step()

            optim_gen.zero_grad()
            # train generator
            scores_fake = model_discr(batch_fake)
            labels.fill_(label_true)
            loss_g_fake = obj_adv(scores_fake.view(-1), labels)
            logits = model_seg(batch_sim)
            seg_loss = loss_seg(logits.permute(0, 2, 3, 1).contiguous().view(-1, num_classes),
                                label_sim.view(-1))
            (loss_g_fake + seg_loss).backward()

            optim_gen.step()

            loss_discr_fake_sum += loss_d_fake.item()
            loss_discr_true_sum += loss_d_true.item()
            loss_gen_sum += loss_g_fake.item()
            loss_tot = loss_d_fake.item() + loss_d_true.item()

            X, data = np.array([batch_count]), np.array([loss_tot])
            visualiser.plot(X, data, title='Loss per batch', legend=['Loss'], iteration=2, update='append')

            batch_since_eval += 1

        # DO EVAL
        torch.cuda.empty_cache()
        results['loss_gen'] = np.divide(loss_gen_sum, batch_since_eval)
        results['loss_discr_fake'] = np.divide(loss_discr_fake_sum, batch_since_eval)
        results['loss_discr_true'] = np.divide(loss_discr_true_sum, batch_since_eval)

        if seg_model_path:
            results.update(evaluate_segtransfer(model_seg, sim_data.get_valid_iterator(),
                                                real_data.get_valid_iterator(), device))
            # early_stopper.update(results, epoch_id=eval_count, batch_id=batch_count)
        log_and_viz_results(results, batch_count, eval_count, visualiser, start, save_dir)

        if save_dir and batch_count % batch_per_save == 0:
            # model_gen.save(os.path.join(save_dir, '{}_{}.pth'.format(model_gen.name, batch_count)))
            model_seg.save(os.path.join(save_dir, '{}_{}.pth'.format(model_seg.name, batch_count)))
            model_discr.save(os.path.join(save_dir, '{}_{}.pth'.format(model_discr.name, batch_count)))

        if batch_count >= max_num_batch:
            print('Stopping training..')
            break

Example #4
def main(data_sim_dir, data_real_dir, data_label_dir, save_dir, visdom_dir,
         batch_size, config_file, discr_model_name, gen_model_name,
         early_stop_patience, max_num_batch, server, port, reload, run_name,
         batch_per_eval, batch_per_save, seg_model_path, seg_model_name,
         content_weight):
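    """
    Train a GAN for pixel-level style transfer: the generator maps real
    frames toward the simulated domain, with an adversarial loss plus a
    content (MSE) loss against the source frame; if a pretrained segmentation
    model is given, it is evaluated on the translated frames.
    """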

    print('Loading data..')
    num_classes = 4
    input_channels = 3
    real_data = InfiniteProviderFromPartitions(input_dir=data_real_dir,
                                               label_dir=None,
                                               num_workers=0,
                                               partition_batch_size=batch_size,
                                               partition_num_workers=2)

    sim_data = InfiniteProviderFromPartitions(input_dir=data_sim_dir,
                                              label_dir=data_label_dir,
                                              num_workers=0,
                                              partition_batch_size=batch_size,
                                              partition_num_workers=2)

    real_data.init_train_iterator()
    sim_data.init_train_iterator()

    print('Building model & loading on GPU (if applicable)..')
    model_gen = models.get_generator_model(gen_model_name,
                                           input_channels).to(device)
    model_discr = models.get_discriminator_model(discr_model_name,
                                                 input_channels).to(device)
    if save_dir:
        model_gen.save(
            os.path.join(save_dir, '{}_{}.pth'.format(model_gen.name, 0)))
        model_discr.save(
            os.path.join(save_dir, '{}_{}.pth'.format(model_discr.name, 0)))

    if seg_model_path:
        assert seg_model_name in seg_model_path
        model_seg = models.get_seg_model(seg_model_name, num_classes,
                                         input_channels).to(device)
        model_seg.load(seg_model_path)

    obj_adv = nn.BCELoss()
    content_obj = nn.MSELoss(reduction='none')
    # float labels: BCELoss expects float targets
    label_true = 1.
    label_fake = 0.

    optim_gen = optim.Adam(model_gen.parameters())
    optim_discr = optim.Adam(model_discr.parameters())

    print('Initializing misc..')
    batch_count = 0
    results = dict()
    results['accuracy'] = 0
    early_stopper = EarlyStopper('accuracy', early_stop_patience)
    visualiser = vis.Visualiser(server, port, run_name, reload, visdom_dir)

    print('Starting training..')
    for eval_count in count(start=1):

        batch_since_eval = 0
        loss_discr_true_sum = 0
        loss_discr_fake_sum = 0
        loss_gen_style_sum = 0
        loss_gen_content_sum = 0
        model_gen.train()
        model_discr.train()
        start = timer()

        while batch_since_eval < batch_per_eval:

            batch_real = next(real_data).to(device)
            batch_sim = next(sim_data)
            batch_sim = batch_sim[0].to(device)
            assert batch_sim.shape[0] == batch_real.shape[0]
            b_size = batch_sim.shape[0]
            batch_count += 1

            optim_discr.zero_grad()
            # train discriminator on true (simulated) frames: log D(x)
            scores_true = model_discr(batch_sim)
            labels = torch.full((b_size, ), label_true, device=device)
            loss_d_true = obj_adv(scores_true.view(-1), labels)
            loss_d_true.backward()

            # train discriminator on fake data: log(1 - D(G(x)))
            batch_fake = model_gen(batch_real)
            scores_fake = model_discr(batch_fake.detach())
            labels.fill_(label_fake)
            loss_d_fake = obj_adv(scores_fake.view(-1), labels)
            loss_d_fake.backward()

            optim_discr.step()

            optim_gen.zero_grad()
            # train generator
            scores_fake = model_discr(batch_fake)
            labels.fill_(label_true)
            loss_g_style = obj_adv(scores_fake.view(-1), labels)
            loss_g_tot = loss_g_style
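            # content loss: squared error to the source frame, summed per
            # image, averaged over the batch, and scaled by content_weight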
            loss_g_content = content_weight * content_obj(
                batch_fake, batch_real).sum((1, 2, 3)).mean()
            loss_g_tot += loss_g_content
            loss_g_tot.backward()

            optim_gen.step()

            loss_discr_fake_sum += loss_d_fake.item()
            loss_discr_true_sum += loss_d_true.item()
            loss_gen_style_sum += loss_g_style.item()
            loss_gen_content_sum += loss_g_content.item()
            loss_tot = (loss_d_fake.item() + loss_d_true.item() +
                        loss_g_style.item() + loss_g_content.item())

            X, data = np.array([batch_count]), np.array([loss_tot])
            visualiser.plot(X,
                            data,
                            title='Loss per batch',
                            legend=['Loss total'],
                            iteration=2,
                            update='append')

            batch_since_eval += 1

        # DO EVAL
        torch.cuda.empty_cache()
        results['loss_gen_style'] = np.divide(loss_gen_style_sum,
                                              batch_since_eval)
        results['loss_gen_content'] = np.divide(loss_gen_content_sum,
                                                batch_since_eval)
        results['loss_discr_fake'] = np.divide(loss_discr_fake_sum,
                                               batch_since_eval)
        results['loss_discr_true'] = np.divide(loss_discr_true_sum,
                                               batch_since_eval)

        if seg_model_path:
            results.update(
                evaluate_transfer(model_seg, model_gen,
                                  real_data.get_valid_iterator(), device))
            # early_stopper.update(results, epoch_id=eval_count, batch_id=batch_count)
        log_and_viz_results(results, batch_count, eval_count, visualiser,
                            start, save_dir)

        if save_dir and batch_count % batch_per_save == 0:
            model_gen.save(
                os.path.join(save_dir,
                             '{}_{}.pth'.format(model_gen.name, batch_count)))
            model_discr.save(
                os.path.join(save_dir,
                             '{}_{}.pth'.format(model_discr.name,
                                                batch_count)))

        if batch_count >= max_num_batch:
            print('Stopping training..')
            break