# assumed imports for the scripts in this module; project-local helpers
# (models, vis, get_seg_model, get_frames, CustomDataset, logit_to_img,
# convert_batch_to_frames, the data providers, EarlyStopper and the
# evaluate/log helpers) are defined elsewhere in this repo
import os
from itertools import count
from timeit import default_timer as timer

import cv2
import imageio
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# assumed device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def main(input_file, output_file, segmentation_model, seg_model_name,
         transform_model_name, frame_skip, batch_size, max_frames):
    """Create a .gif with lane lines segmented from a video file."""
    assert output_file.endswith('.gif'), 'Make sure output_file is a .gif'

    print('Loading models..')
    num_classes = 4
    input_channels = 3
    model_seg = get_seg_model(seg_model_name, num_classes,
                              input_channels).to(device)
    model_seg.load(segmentation_model)

    print('Loading data..')
    cap = cv2.VideoCapture(input_file)
    images = torch.tensor(get_frames(cap, frame_skip,
                                     max_frames)).to(torch.uint8)
    num_images = images.shape[0]
    data_iterator = DataLoader(dataset=CustomDataset(images),
                               batch_size=batch_size,
                               shuffle=False,
                               num_workers=2,
                               drop_last=False)
    print('\tNumber of frames to convert:\t{} (frame skip: {})'.format(
        num_images, frame_skip))

    print('Converting..')
    model_seg.eval()
    batch_count = 0
    gif_frames = []
    with torch.no_grad():
        start = timer()
        for batch in data_iterator:  # `batch` avoids shadowing `images`
            batch_count += 1

            # get segmentation predictions for the batch
            seg_logits = model_seg(batch.to(device))
            seg_preds = torch.argmax(seg_logits, dim=1).cpu()

            # the dataset yields float frames in [0, 1]; rescale to uint8
            source = torch.mul(batch.cpu(), 255).to(torch.uint8)
            segmented = logit_to_img(seg_preds.numpy()).transpose(0, 3, 1, 2)
            segmented = torch.mul(torch.tensor(segmented), 255).to(torch.uint8)

            # convert torch predictions to frames of grid
            gif_frames.extend(convert_batch_to_frames(source, segmented))

            if batch_count % 50 == 0:
                print('\tframe {} / {} - {:.2f} secs'.format(
                    batch_count * batch_size, num_images, timer() - start))
                start = timer()
            del batch, source, segmented

    # convert sequence of frames into gif
    print('Saving {}..'.format(output_file))
    imageio.mimsave(output_file, gif_frames, fps=29.97 / frame_skip,
                    subrectangles=True)
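# The gif loop above assumes `CustomDataset` feeds the DataLoader float
# frames scaled to [0, 1] (hence the `* 255` when rebuilding the source
# frames). A minimal sketch of such a wrapper, purely illustrative -- the
# actual class is defined elsewhere in this repo:
from torch.utils.data import Dataset


class CustomDatasetSketch(Dataset):
    """Hypothetical sketch: wrap a (N, C, H, W) uint8 frame tensor and
    yield float32 frames scaled to [0, 1], as the gif loop expects."""

    def __init__(self, frames):
        self.frames = frames  # uint8 tensor, shape (N, C, H, W)

    def __len__(self):
        return self.frames.shape[0]

    def __getitem__(self, idx):
        return self.frames[idx].to(torch.float32) / 255.0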
def main(data_sim_dir, data_real_dir, data_label_dir, save_dir, visdom_dir,
         batch_size, config_file, seg_model_name, early_stop_patience, server,
         port, reload, run_name):
    print('Loading data..')
    num_classes = 4
    input_channels = 3
    real_data = PartitionProvider(input_dir=data_real_dir,
                                  label_dir=None,
                                  num_workers=0,
                                  partition_batch_size=batch_size,
                                  partition_num_workers=2)
    sim_data = PartitionProvider(input_dir=data_sim_dir,
                                 label_dir=data_label_dir,
                                 num_workers=0,
                                 partition_batch_size=batch_size,
                                 partition_num_workers=2)

    print('Building model & loading on GPU (if applicable)..')
    seg_model = models.get_seg_model(seg_model_name, num_classes,
                                     input_channels).to(device)
    if save_dir:
        seg_model.save(os.path.join(save_dir, '{}.pth'.format(seg_model.name)))

    # adjusted class weights [black, white, red, yellow], see README.md
    class_weights = torch.tensor([0.0051, 0.0551, 0.6538, 0.2860]).to(device)
    loss = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(seg_model.parameters())

    print('Initializing misc..')
    batch_count = 0
    partition_count = 0
    results = dict()
    early_stopper = EarlyStopper('accuracy', early_stop_patience)
    visualiser = vis.Visualiser(server, port, run_name, reload, visdom_dir)

    print('Starting training..')
    for epoch_id in count(start=1):
        for sim_partition in sim_data.train_partition_iterator:
            start = timer()
            seg_model.train()
            partition_loss = 0
            batch_per_part = 0
            partition_count += 1
            sim_data_train_iterator = sim_data.get_train_iterator(sim_partition)
            for batch_id, batch in enumerate(sim_data_train_iterator):
                batch_count += 1
                batch_per_part += 1
                inputs, labels = batch  # `inputs` avoids shadowing the builtin
                logits = seg_model(inputs.to(device))
                optimizer.zero_grad()
                # flatten (N, C, H, W) logits to (N*H*W, C) for pixel-wise CE
                loss_seg = loss(
                    logits.permute(0, 2, 3, 1).contiguous().view(-1, num_classes),
                    labels.view(-1).to(device))
                loss_seg.backward()
                optimizer.step()
                partition_loss += loss_seg.item()

                X, data = np.array([batch_count]), np.array(
                    [loss_seg.detach().to('cpu').numpy()])
                visualiser.plot(X, data, title='Loss per batch',
                                legend=['Loss'], iteration=2, update='append')
                del logits
                del loss_seg
            optimizer.zero_grad()
            torch.cuda.empty_cache()

            results['partition_avg_loss'] = np.divide(partition_loss,
                                                      batch_per_part)
            results.update(evaluate(seg_model, sim_data, device))
            early_stopper.update(results, epoch_id, batch_count)
            log_and_viz_results(results, epoch_id, batch_count,
                                partition_count, visualiser, start)

            if early_stopper.new_best and save_dir:
                seg_model.save(
                    os.path.join(save_dir, '{}.pth'.format(seg_model.name)))
            if early_stopper.stop:
                early_stopper.print_stop()
                return
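# `EarlyStopper` above tracks a single metric ('accuracy') and exposes the
# `new_best` / `stop` flags and `print_stop()` used in the training loop.
# A minimal sketch of that contract, assuming a higher-is-better metric --
# the real class lives elsewhere in this repo:
class EarlyStopperSketch:
    """Hypothetical sketch: stop once `metric` has not improved for
    `patience` consecutive updates."""

    def __init__(self, metric, patience):
        self.metric = metric
        self.patience = patience
        self.best = float('-inf')
        self.bad_updates = 0
        self.new_best = False
        self.stop = False

    def update(self, results, epoch_id, batch_id):
        value = results[self.metric]
        self.new_best = value > self.best
        if self.new_best:
            self.best = value
            self.bad_updates = 0
        else:
            self.bad_updates += 1
        self.stop = self.bad_updates >= self.patience

    def print_stop(self):
        print('Early stopping: best {} = {:.4f}'.format(self.metric, self.best))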
def main(data_sim_dir, data_real_dir, data_label_dir, save_dir, visdom_dir,
         batch_size, config_file, discr_model_name, gen_model_name,
         early_stop_patience, max_num_batch, server, port, reload, run_name,
         batch_per_eval, batch_per_save, seg_model_path, seg_model_name,
         content_weight):
    print('Loading data..')
    num_classes = 4
    input_channels = 3
    real_data = InfiniteProviderFromPartitions(input_dir=data_real_dir,
                                               label_dir=None,
                                               num_workers=0,
                                               partition_batch_size=batch_size,
                                               partition_num_workers=2)
    sim_data = InfiniteProviderFromPartitions(input_dir=data_sim_dir,
                                              label_dir=data_label_dir,
                                              num_workers=0,
                                              partition_batch_size=batch_size,
                                              partition_num_workers=2)
    real_data.init_train_iterator()
    sim_data.init_train_iterator()

    print('Building model & loading on GPU (if applicable)..')
    if seg_model_path:
        assert seg_model_name in seg_model_path
    model_seg = models.get_seg_model(seg_model_name, num_classes,
                                     input_channels).to(device)
    model_seg.load(seg_model_path)
    model_seg.name = 'segnet_transfer'
    model_discr = models.get_discriminator_model(discr_model_name,
                                                 model_seg.size_bottleneck,
                                                 stride=1,
                                                 flat_size=128 * 5 * 3).to(device)
    if save_dir:
        model_discr.save(
            os.path.join(save_dir, '{}_{}.pth'.format(model_discr.name, 0)))

    obj_adv = nn.BCELoss()
    label_true = 1.
    label_fake = 0.
    # only the encoder acts as the generator: exclude decoder parameters
    # (assumes decoder parameter names are the ones containing 'd')
    optim_gen = optim.Adam(p for n, p in model_seg.named_parameters()
                           if 'd' not in n)
    optim_discr = optim.Adam(model_discr.parameters())

    print('Initializing misc..')
    batch_count = 0
    results = dict()
    results['accuracy'] = 0
    visualiser = vis.Visualiser(server, port, run_name, reload, visdom_dir)
    class_weights = torch.tensor([0.0051, 0.0551, 0.6538, 0.2860]).to(device)
    loss_seg = nn.CrossEntropyLoss(weight=class_weights)

    print('Starting training..')
    for eval_count in count(start=1):
        batch_since_eval = 0
        loss_discr_true_sum = 0
        loss_discr_fake_sum = 0
        loss_gen_sum = 0
        model_seg.train()
        model_discr.train()
        start = timer()
        while batch_since_eval < batch_per_eval:
            batch_real = next(real_data).to(device)
            batch_sim, label_sim = next(sim_data)
            batch_sim, label_sim = batch_sim.to(device), label_sim.to(device)
            assert batch_sim.shape[0] == batch_real.shape[0]
            b_size = batch_sim.shape[0]
            batch_count += 1

            optim_discr.zero_grad()
            # train discriminator on true data (log D(x))
            emb_sim = model_seg(batch_sim, bottleneck=True)
            scores_true = model_discr(emb_sim.detach())
            labels = torch.full((b_size,), label_true, device=device)
            loss_d_true = obj_adv(scores_true.view(-1), labels)
            loss_d_true.backward()

            # train discriminator on fake data (log(1 - D(G(x))))
            batch_fake = model_seg(batch_real, bottleneck=True)
            scores_fake = model_discr(batch_fake.detach())
            labels.fill_(label_fake)
            loss_d_fake = obj_adv(scores_fake.view(-1), labels)
            loss_d_fake.backward()
            optim_discr.step()

            optim_gen.zero_grad()
            # train generator (the encoder): fool the discriminator while
            # preserving segmentation performance on simulated data
            scores_fake = model_discr(batch_fake)
            labels.fill_(label_true)
            loss_g_fake = obj_adv(scores_fake.view(-1), labels)
            logits = model_seg(batch_sim)
            seg_loss = loss_seg(
                logits.permute(0, 2, 3, 1).contiguous().view(-1, num_classes),
                label_sim.view(-1))
            (loss_g_fake + seg_loss).backward()
            optim_gen.step()

            loss_discr_fake_sum += loss_d_fake.item()
            loss_discr_true_sum += loss_d_true.item()
            loss_gen_sum += loss_g_fake.item()
            loss_tot = loss_d_fake.item() + loss_d_true.item()
            X, data = np.array([batch_count]), np.array([loss_tot])
            visualiser.plot(X, data, title='Loss per batch', legend=['Loss'],
                            iteration=2, update='append')
            batch_since_eval += 1

        # DO EVAL
        torch.cuda.empty_cache()
        results['loss_gen'] = np.divide(loss_gen_sum, batch_since_eval)
        results['loss_discr_fake'] = np.divide(loss_discr_fake_sum,
                                               batch_since_eval)
        results['loss_discr_true'] = np.divide(loss_discr_true_sum,
                                               batch_since_eval)
        if seg_model_path:
            results.update(
                evaluate_segtransfer(model_seg, sim_data.get_valid_iterator(),
                                     real_data.get_valid_iterator(), device))
        # early_stopper.update(results, epoch_id=eval_count, batch_id=batch_count)
        log_and_viz_results(results, batch_count, eval_count, visualiser,
                            start, save_dir)

        if save_dir and batch_count % batch_per_save == 0:
            # model_gen.save(os.path.join(save_dir, '{}_{}.pth'.format(model_gen.name, batch_count)))
            model_seg.save(
                os.path.join(save_dir,
                             '{}_{}.pth'.format(model_seg.name, batch_count)))
            model_discr.save(
                os.path.join(save_dir,
                             '{}_{}.pth'.format(model_discr.name, batch_count)))
        if batch_count >= max_num_batch:
            print('Stopping training..')
            break
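# The loop above relies on the segmentation network exposing its encoder
# output through a `bottleneck=True` flag, and on decoder parameter names
# containing 'd' so the generator optimizer can skip them. A minimal sketch
# of that contract, purely illustrative -- the real models live in `models`:
import torch.nn as nn


class SegNetSketch(nn.Module):
    """Hypothetical sketch: one module that returns encoder features when
    called with bottleneck=True and full segmentation logits otherwise."""

    def __init__(self, num_classes, input_channels):
        super().__init__()
        # encoder blocks (names without 'd', so the generator optimizer
        # in the loop above picks them up)
        self.c1 = nn.Sequential(nn.Conv2d(input_channels, 64, 3, padding=1),
                                nn.ReLU())
        self.c2 = nn.Sequential(nn.Conv2d(64, 128, 3, padding=1), nn.ReLU())
        # decoder head (name contains 'd', excluded from the generator)
        self.d1 = nn.Conv2d(128, num_classes, 1)
        self.size_bottleneck = 128

    def forward(self, x, bottleneck=False):
        features = self.c2(self.c1(x))
        if bottleneck:
            return features  # encoder embedding fed to the discriminator
        return self.d1(features)  # per-pixel class logits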
def main(data_sim_dir, data_real_dir, data_label_dir, save_dir, visdom_dir,
         batch_size, config_file, discr_model_name, gen_model_name,
         early_stop_patience, max_num_batch, server, port, reload, run_name,
         batch_per_eval, batch_per_save, seg_model_path, seg_model_name,
         content_weight):
    print('Loading data..')
    num_classes = 4
    input_channels = 3
    real_data = InfiniteProviderFromPartitions(input_dir=data_real_dir,
                                               label_dir=None,
                                               num_workers=0,
                                               partition_batch_size=batch_size,
                                               partition_num_workers=2)
    sim_data = InfiniteProviderFromPartitions(input_dir=data_sim_dir,
                                              label_dir=data_label_dir,
                                              num_workers=0,
                                              partition_batch_size=batch_size,
                                              partition_num_workers=2)
    real_data.init_train_iterator()
    sim_data.init_train_iterator()

    print('Building model & loading on GPU (if applicable)..')
    model_gen = models.get_generator_model(gen_model_name,
                                           input_channels).to(device)
    model_discr = models.get_discriminator_model(discr_model_name,
                                                 input_channels).to(device)
    if save_dir:
        model_gen.save(
            os.path.join(save_dir, '{}_{}.pth'.format(model_gen.name, 0)))
        model_discr.save(
            os.path.join(save_dir, '{}_{}.pth'.format(model_discr.name, 0)))
    if seg_model_path:
        assert seg_model_name in seg_model_path
        model_seg = models.get_seg_model(seg_model_name, num_classes,
                                         input_channels).to(device)
        model_seg.load(seg_model_path)

    obj_adv = nn.BCELoss()
    content_obj = nn.MSELoss(reduction='none')
    label_true = 1.
    label_fake = 0.
    optim_gen = optim.Adam(model_gen.parameters())
    optim_discr = optim.Adam(model_discr.parameters())

    print('Initializing misc..')
    batch_count = 0
    results = dict()
    results['accuracy'] = 0
    early_stopper = EarlyStopper('accuracy', early_stop_patience)
    visualiser = vis.Visualiser(server, port, run_name, reload, visdom_dir)

    print('Starting training..')
    for eval_count in count(start=1):
        batch_since_eval = 0
        loss_discr_true_sum = 0
        loss_discr_fake_sum = 0
        loss_gen_style_sum = 0
        loss_gen_content_sum = 0
        model_gen.train()
        model_discr.train()
        start = timer()
        while batch_since_eval < batch_per_eval:
            batch_real = next(real_data).to(device)
            # sim provider yields (input, label) pairs; labels are unused here
            batch_sim = next(sim_data)[0].to(device)
            assert batch_sim.shape[0] == batch_real.shape[0]
            b_size = batch_sim.shape[0]
            batch_count += 1

            optim_discr.zero_grad()
            # train discriminator on true data (log D(x))
            scores_true = model_discr(batch_sim)
            labels = torch.full((b_size,), label_true, device=device)
            loss_d_true = obj_adv(scores_true.view(-1), labels)
            loss_d_true.backward()

            # train discriminator on fake data (log(1 - D(G(x))))
            batch_fake = model_gen(batch_real)
            scores_fake = model_discr(batch_fake.detach())
            labels.fill_(label_fake)
            loss_d_fake = obj_adv(scores_fake.view(-1), labels)
            loss_d_fake.backward()
            optim_discr.step()

            optim_gen.zero_grad()
            # train generator: adversarial (style) term plus a content term
            # that keeps G(x) close to its real input
            scores_fake = model_discr(batch_fake)
            labels.fill_(label_true)
            loss_g_style = obj_adv(scores_fake.view(-1), labels)
            loss_g_content = content_weight * content_obj(
                batch_fake, batch_real).sum((1, 2, 3)).mean()
            loss_g_tot = loss_g_style + loss_g_content
            loss_g_tot.backward()
            optim_gen.step()

            loss_discr_fake_sum += loss_d_fake.item()
            loss_discr_true_sum += loss_d_true.item()
            loss_gen_style_sum += loss_g_style.item()
            loss_gen_content_sum += loss_g_content.item()
            loss_tot = (loss_d_fake.item() + loss_d_true.item() +
                        loss_g_style.item() + loss_g_content.item())
            X, data = np.array([batch_count]), np.array([loss_tot])
            visualiser.plot(X, data, title='Loss per batch',
                            legend=['Loss total'], iteration=2,
                            update='append')
            batch_since_eval += 1

        # DO EVAL
        torch.cuda.empty_cache()
        results['loss_gen_style'] = np.divide(loss_gen_style_sum,
                                              batch_since_eval)
        results['loss_gen_content'] = np.divide(loss_gen_content_sum,
                                                batch_since_eval)
        results['loss_discr_fake'] = np.divide(loss_discr_fake_sum,
                                               batch_since_eval)
        results['loss_discr_true'] = np.divide(loss_discr_true_sum,
                                               batch_since_eval)
        if seg_model_path:
            results.update(
                evaluate_transfer(model_seg, model_gen,
                                  real_data.get_valid_iterator(), device))
        # early_stopper.update(results, epoch_id=eval_count, batch_id=batch_count)
        log_and_viz_results(results, batch_count, eval_count, visualiser,
                            start, save_dir)

        if save_dir and batch_count % batch_per_save == 0:
            model_gen.save(
                os.path.join(save_dir,
                             '{}_{}.pth'.format(model_gen.name, batch_count)))
            model_discr.save(
                os.path.join(save_dir,
                             '{}_{}.pth'.format(model_discr.name, batch_count)))
        if batch_count >= max_num_batch:
            print('Stopping training..')
            break
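# Once training saves '{name}_{batch}.pth' checkpoints, the generator can be
# applied to real frames on its own. A minimal sketch, assuming the generator
# exposes the same `load` helper the segmentation model uses and that frames
# arrive as float tensors in [0, 1] with shape (N, C, H, W):
import torch


def stylize_frames(gen_model_name, checkpoint_path, frames, input_channels=3):
    """Hypothetical sketch: run a saved generator over a batch of real frames."""
    model_gen = models.get_generator_model(gen_model_name,
                                           input_channels).to(device)
    model_gen.load(checkpoint_path)  # assumes the same .load helper as model_seg
    model_gen.eval()
    with torch.no_grad():
        return model_gen(frames.to(device)).cpu()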