Ejemplo n.º 1
0
def main(eval_args):
    """Restore a trained AutoEncoder from a checkpoint, then evaluate it or sample from it.

    If ``eval_args.eval_mode == 'evaluate'`` the model is scored with ``test``
    on the validation loader (or on the training loader when
    ``eval_args.eval_on_train`` is set) and the NELBO / negative log p values
    are logged, both raw and multiplied by the bits-per-dimension coefficient.
    For any other mode, ``set_bn`` is applied to the model and ten batches of
    16 samples are drawn at temperature ``eval_args.temp``, timed, passed
    through ``utils.tile_image`` and shown with matplotlib.
    """
    logger = utils.Logger(eval_args.local_rank, eval_args.save)

    # Restore the checkpoint on the CPU; the model is moved to the GPU further down.
    logger.info('loading the model at:')
    logger.info(eval_args.checkpoint)
    checkpoint = torch.load(eval_args.checkpoint, map_location='cpu')
    args = checkpoint['args']

    # Checkpoints written by older code lack some options; fill in each
    # missing one (attribute name, log message, value to assign).
    legacy_defaults = (
        ('ada_groups', 'old model, no ada groups was found.', False),
        ('min_groups_per_scale', 'old model, no min_groups_per_scale was found.', 1),
        ('num_mixture_dec', 'old model, no num_mixture_dec was found.', 10),
    )
    for attr_name, notice, fallback in legacy_defaults:
        if not hasattr(args, attr_name):
            logger.info(notice)
            setattr(args, attr_name, fallback)

    logger.info('loaded the model at epoch %d', checkpoint['epoch'])
    cells = utils.get_arch_cells(args.arch_instance)
    model = AutoEncoder(args, None, cells)
    # Non-strict loading: self.weight_normalized in the Conv2D class of
    # neural_operations is only used for computing the spectral normalization,
    # so it is safe not to load it, and some earlier models do not have it.
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    model = model.cuda()

    logger.info('args = %s', args)
    logger.info('num conv layers: %d', len(model.all_conv_layers))
    logger.info('param size = %fM ', utils.count_parameters_in_M(model))

    if eval_args.eval_mode == 'evaluate':
        # Point the stored training arguments at the data location of this run.
        args.data = eval_args.data
        train_queue, valid_queue, _ = datasets.get_loaders(args)

        eval_queue = valid_queue
        if eval_args.eval_on_train:
            logger.info('Using the training data for eval.')
            eval_queue = train_queue

        # Scale factor used for the "in bpd" log lines below.
        bpd_coeff = 1. / np.log(2.) / utils.num_output(args.dataset)

        neg_log_p, nelbo = test(eval_queue,
                                model,
                                num_samples=eval_args.num_iw_samples,
                                args=args,
                                logging=logger)
        logger.info('final valid nelbo %f', nelbo)
        logger.info('final valid neg log p %f', neg_log_p)
        logger.info('final valid nelbo in bpd %f', nelbo * bpd_coeff)
        logger.info('final valid neg log p in bpd %f', neg_log_p * bpd_coeff)
        return

    # Any other mode: draw and display samples.
    bn_eval_mode = not eval_args.readjust_bn
    batch_size = 16
    with torch.no_grad():
        grid_side = int(np.floor(np.sqrt(batch_size)))
        set_bn(model, bn_eval_mode, num_samples=36, t=eval_args.temp, iter=500)
        for _ in range(10):  # sampling is repeated.
            # Synchronize around the timed region so the measurement covers the GPU work.
            torch.cuda.synchronize()
            tic = time()
            with autocast():
                logits = model.sample(batch_size, eval_args.temp)
            output = model.decoder_output(logits)
            if isinstance(output, torch.distributions.bernoulli.Bernoulli):
                output_img = output.mean
            else:
                output_img = output.sample()
            torch.cuda.synchronize()
            elapsed = time() - tic

            tiled = utils.tile_image(output_img, grid_side)
            tiled = tiled.cpu().numpy().transpose(1, 2, 0)
            logger.info('sampling time per batch: %0.3f sec', elapsed)
            tiled = np.squeeze(np.asarray(tiled * 255, dtype=np.uint8))

            plt.imshow(tiled)
            plt.show()
Ejemplo n.º 2
0
def main(args):
    """Train an AutoEncoder, validating and checkpointing along the way.

    The function seeds the RNGs, builds the data loaders, model, Adamax
    optimizer, cosine learning-rate scheduler and gradient scaler, optionally
    resumes from ``<args.save>/checkpoint.pt`` when ``args.cont_training`` is
    set, and then runs ``train`` once per epoch.  Samples and validation
    metrics are written every ``eval_freq`` epochs, a checkpoint is saved every
    ``save_freq`` epochs (rank 0 only), and a final validation with 1000
    importance-weighted samples is logged after the last epoch.

    Args:
        args: Parsed training options.  ``args.num_total_iter`` is set here.
    """
    # ensures that weight initializations are all the same
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    logging = utils.Logger(args.global_rank, args.save)
    writer = utils.Writer(args.global_rank, args.save)

    # Get data loaders.
    train_queue, valid_queue, _, _ = datasets.get_loaders(args)
    args.num_total_iter = len(train_queue) * args.epochs
    warmup_iters = len(train_queue) * args.warmup_epochs

    arch_instance = utils.get_arch_cells(args.arch_instance)

    model = AutoEncoder(args, writer, arch_instance)
    model = model.cuda()

    logging.info('args = %s', args)
    logging.info('param size = %fM ', utils.count_parameters_in_M(model))
    logging.info('groups per scale: %s, total_groups: %d',
                 model.groups_per_scale, sum(model.groups_per_scale))

    if args.fast_adamax:
        # Fast adamax has the same functionality as torch.optim.Adamax, except it is faster.
        cnn_optimizer = Adamax(model.parameters(),
                               args.learning_rate,
                               weight_decay=args.weight_decay,
                               eps=1e-3)
    else:
        cnn_optimizer = torch.optim.Adamax(model.parameters(),
                                           args.learning_rate,
                                           weight_decay=args.weight_decay,
                                           eps=1e-3)

    # The cosine schedule only covers the epochs after warm-up (see the
    # `epoch > args.warmup_epochs` guard in the loop below).
    cnn_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        cnn_optimizer,
        float(args.epochs - args.warmup_epochs - 1),
        eta_min=args.learning_rate_min)
    grad_scalar = GradScaler(2**10)

    num_output = utils.num_output(args.dataset, args)
    bpd_coeff = 1. / np.log(2.) / num_output

    # Resume from the checkpoint in the save directory when requested.
    checkpoint_file = os.path.join(args.save, 'checkpoint.pt')
    if args.cont_training:
        logging.info('loading the model.')
        checkpoint = torch.load(checkpoint_file, map_location='cpu')
        init_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        model = model.cuda()
        cnn_optimizer.load_state_dict(checkpoint['optimizer'])
        grad_scalar.load_state_dict(checkpoint['grad_scalar'])
        cnn_scheduler.load_state_dict(checkpoint['scheduler'])
        global_step = checkpoint['global_step']
    else:
        global_step, init_epoch = 0, 0

    for epoch in range(init_epoch, args.epochs):
        # update lrs.
        if args.distributed:
            train_queue.sampler.set_epoch(global_step + args.seed)
            valid_queue.sampler.set_epoch(0)

        if epoch > args.warmup_epochs:
            cnn_scheduler.step()

        # Logging.
        logging.info('epoch %d', epoch)

        # Training.
        train_nelbo, global_step = train(train_queue, model, cnn_optimizer,
                                         grad_scalar, global_step,
                                         warmup_iters, writer, logging)
        logging.info('train_nelbo %f', train_nelbo)
        writer.add_scalar('train/nelbo', train_nelbo, global_step)

        model.eval()
        # generate samples less frequently
        eval_freq = 1 if args.epochs <= 50 else 20
        if epoch % eval_freq == 0 or epoch == (args.epochs - 1):
            with torch.no_grad():
                num_samples = 16
                n = int(np.floor(np.sqrt(num_samples)))
                for t in [0.7, 0.8, 0.9, 1.0]:
                    logits = model.sample(num_samples, t)
                    output = model.decoder_output(logits)
                    output_img = output.mean if isinstance(
                        output, torch.distributions.bernoulli.Bernoulli
                    ) else output.sample(t)
                    output_tiled = utils.tile_image(output_img, n)
                    writer.add_image('generated_%0.1f' % t, output_tiled,
                                     global_step)

            valid_neg_log_p, valid_nelbo = test(valid_queue,
                                                model,
                                                num_samples=10,
                                                args=args,
                                                logging=logging)
            logging.info('valid_nelbo %f', valid_nelbo)
            logging.info('valid neg log p %f', valid_neg_log_p)
            logging.info('valid bpd elbo %f', valid_nelbo * bpd_coeff)
            logging.info('valid bpd log p %f', valid_neg_log_p * bpd_coeff)
            writer.add_scalar('val/neg_log_p', valid_neg_log_p, epoch)
            writer.add_scalar('val/nelbo', valid_nelbo, epoch)
            writer.add_scalar('val/bpd_log_p', valid_neg_log_p * bpd_coeff,
                              epoch)
            writer.add_scalar('val/bpd_elbo', valid_nelbo * bpd_coeff, epoch)

        save_freq = int(np.ceil(args.epochs / 100))
        if epoch % save_freq == 0 or epoch == (args.epochs - 1):
            if args.global_rank == 0:
                logging.info('saving the model.')
                torch.save(
                    {
                        # Stored as the next epoch to run, so a resumed job
                        # starts right after the one that was just finished.
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': cnn_optimizer.state_dict(),
                        'global_step': global_step,
                        'args': args,
                        'arch_instance': arch_instance,
                        'scheduler': cnn_scheduler.state_dict(),
                        'grad_scalar': grad_scalar.state_dict()
                    }, checkpoint_file)

    # The loop variable is unbound when the loop body never runs (e.g. resuming
    # from a checkpoint whose stored epoch is already >= args.epochs), so the
    # step index for the final metrics is computed explicitly.  It equals
    # `epoch + 1` whenever at least one epoch was executed.
    final_epoch = max(init_epoch, args.epochs)

    # Final validation
    # Idempotent after a normal run; puts a freshly resumed model in eval mode too.
    model.eval()
    valid_neg_log_p, valid_nelbo = test(valid_queue,
                                        model,
                                        num_samples=1000,
                                        args=args,
                                        logging=logging)
    logging.info('final valid nelbo %f', valid_nelbo)
    logging.info('final valid neg log p %f', valid_neg_log_p)
    writer.add_scalar('val/neg_log_p', valid_neg_log_p, final_epoch)
    writer.add_scalar('val/nelbo', valid_nelbo, final_epoch)
    writer.add_scalar('val/bpd_log_p', valid_neg_log_p * bpd_coeff, final_epoch)
    writer.add_scalar('val/bpd_elbo', valid_nelbo * bpd_coeff, final_epoch)
    writer.close()
Ejemplo n.º 3
0
def train(train_queue, model, cnn_optimizer, grad_scalar, global_step,
          warmup_iters, writer, logging):
    """Run one pass over ``train_queue`` and return the average loss.

    NOTE: besides its parameters, this function reads a name ``args`` that is
    not passed in; it must be available at module scope when this is called.

    Args:
        train_queue: Iterable of training batches.
        model: Model being trained; called as ``model(x)`` and expected to
            return ``(logits, log_q, log_p, kl_all, kl_diag)``.
        cnn_optimizer: Optimizer whose ``param_groups`` learning rate is
            overwritten during warm-up.
        grad_scalar: Gradient scaler used to scale the loss and step the
            optimizer.
        global_step: Number of optimizer steps taken before this call.
        warmup_iters: Number of steps over which the learning rate is ramped
            linearly from 0 to ``args.learning_rate``.
        writer: Summary writer receiving scalars and images.
        logging: Logger used for progress lines.

    Returns:
        A tuple ``(nelbo.avg, global_step)``: the running average tracked by
        the meter and the updated step counter.  Note that the meter is fed
        ``loss`` after the spectral-norm and batch-norm regularizers have been
        added, not the bare NELBO.
    """
    alpha_i = utils.kl_balancer_coeff(num_scales=model.num_latent_scales,
                                      groups_per_scale=model.groups_per_scale,
                                      fun='square')
    nelbo = utils.AvgrageMeter()
    model.train()
    for step, x in enumerate(train_queue):
        # Batches with more than one element keep only the first one
        # (presumably (data, label) pairs -- confirm against the loaders).
        x = x[0] if len(x) > 1 else x
        x = x.half().cuda()

        # change bit length
        x = utils.pre_process(x, args.num_x_bits)

        # warm-up lr: linear ramp in global_step, applied to every param group
        if global_step < warmup_iters:
            lr = args.learning_rate * float(global_step) / warmup_iters
            for param_group in cnn_optimizer.param_groups:
                param_group['lr'] = lr

        # sync parameters, it may not be necessary
        if step % 100 == 0:
            utils.average_params(model.parameters(), args.distributed)

        cnn_optimizer.zero_grad()
        with autocast():
            logits, log_q, log_p, kl_all, kl_diag = model(x)

            output = model.decoder_output(logits)
            # KL weight for the current step, driven by the anneal/constant
            # portions of the total number of iterations.
            kl_coeff = utils.kl_coeff(
                global_step, args.kl_anneal_portion * args.num_total_iter,
                args.kl_const_portion * args.num_total_iter,
                args.kl_const_coeff)

            recon_loss = utils.reconstruction_loss(output,
                                                   x,
                                                   crop=model.crop_output)
            balanced_kl, kl_coeffs, kl_vals = utils.kl_balancer(
                kl_all, kl_coeff, kl_balance=True, alpha_i=alpha_i)

            nelbo_batch = recon_loss + balanced_kl
            loss = torch.mean(nelbo_batch)
            norm_loss = model.spectral_norm_parallel()
            bn_loss = model.batchnorm_loss()
            # get spectral regularization coefficient (lambda)
            if args.weight_decay_norm_anneal:
                # Interpolate between the initial and final coefficients in
                # log space, using kl_coeff as the interpolation weight.
                assert args.weight_decay_norm_init > 0 and args.weight_decay_norm > 0, 'init and final wdn should be positive.'
                wdn_coeff = (1. - kl_coeff) * np.log(
                    args.weight_decay_norm_init) + kl_coeff * np.log(
                        args.weight_decay_norm)
                wdn_coeff = np.exp(wdn_coeff)
            else:
                wdn_coeff = args.weight_decay_norm

            # Both regularizers share the same coefficient.
            loss += norm_loss * wdn_coeff + bn_loss * wdn_coeff

        # Scaled backward pass, gradient averaging, then the scaler-driven
        # optimizer step and scale update.
        grad_scalar.scale(loss).backward()
        utils.average_gradients(model.parameters(), args.distributed)
        grad_scalar.step(cnn_optimizer)
        grad_scalar.update()
        nelbo.update(loss.data, 1)

        # Scalars are written every 100 global steps; the reconstruction
        # image only every 1000.
        if (global_step + 1) % 100 == 0:
            if (global_step + 1) % 1000 == 0:  # reduced frequency
                # Largest square grid that fits in the batch.
                n = int(np.floor(np.sqrt(x.size(0))))
                x_img = x[:n * n]
                output_img = output.mean if isinstance(
                    output, torch.distributions.bernoulli.Bernoulli
                ) else output.sample()
                output_img = output_img[:n * n]
                x_tiled = utils.tile_image(x_img, n)
                output_tiled = utils.tile_image(output_img, n)
                # Inputs and outputs are concatenated along dim 2 into one image.
                in_out_tiled = torch.cat((x_tiled, output_tiled), dim=2)
                writer.add_image('reconstruction', in_out_tiled, global_step)

            # norm
            writer.add_scalar('train/norm_loss', norm_loss, global_step)
            writer.add_scalar('train/bn_loss', bn_loss, global_step)
            writer.add_scalar('train/norm_coeff', wdn_coeff, global_step)

            # The return value is discarded, so this presumably averages the
            # tensor in place -- TODO confirm in utils.average_tensor.
            utils.average_tensor(nelbo.avg, args.distributed)
            logging.info('train %d %f', global_step, nelbo.avg)
            writer.add_scalar('train/nelbo_avg', nelbo.avg, global_step)
            writer.add_scalar(
                'train/lr',
                cnn_optimizer.state_dict()['param_groups'][0]['lr'],
                global_step)
            writer.add_scalar('train/nelbo_iter', loss, global_step)
            writer.add_scalar('train/kl_iter', torch.mean(sum(kl_all)),
                              global_step)
            # The reconstruction loss is recomputed here rather than reusing
            # recon_loss from the autocast block above.
            writer.add_scalar(
                'train/recon_iter',
                torch.mean(
                    utils.reconstruction_loss(output,
                                              x,
                                              crop=model.crop_output)),
                global_step)
            writer.add_scalar('kl_coeff/coeff', kl_coeff, global_step)
            total_active = 0
            for i, kl_diag_i in enumerate(kl_diag):
                utils.average_tensor(kl_diag_i, args.distributed)
                # Entries with a KL value above 0.1 are counted as active.
                num_active = torch.sum(kl_diag_i > 0.1).detach()
                total_active += num_active

                # per-entry active count, KL coefficient and KL value
                writer.add_scalar('kl/active_%d' % i, num_active, global_step)
                writer.add_scalar('kl_coeff/layer_%d' % i, kl_coeffs[i],
                                  global_step)
                writer.add_scalar('kl_vals/layer_%d' % i, kl_vals[i],
                                  global_step)
            writer.add_scalar('kl/total_active', total_active, global_step)

        global_step += 1

    utils.average_tensor(nelbo.avg, args.distributed)
    return nelbo.avg, global_step
Ejemplo n.º 4
0
def run(predict_folder,output_folder, working_dir=constants.working_dir, batch_size = constants.batch_size):
    """Predict segmentation masks for every image in ``predict_folder``.

    Each image is cut into tiles, the tiles are run through the trained U-Net
    in batches, the predicted masks (and the input tiles, with an ``a.png``
    suffix) are written to the prediction-tile folder, and finally
    ``reassemble`` is called once per source image with a destination path in
    ``output_folder``.  A colour legend for the classes is written to
    ``output_folder`` and the temporary folder is emptied at the end.
    """
    temp_dir, tiles_dir, pred_tiles_dir = make_folders(working_dir)
    classes = utils.load_obj(constants.label_map)

    legend_path = os.path.join(output_folder,"color_legend.png")
    utils.create_color_legend(classes, legend_path)

    print("Preparing image tiles...",flush=True)
    images_to_predict = utils.get_all_image_paths_in_folder(predict_folder)
    for source_image in progressbar.progressbar(images_to_predict):
        utils.tile_image(source_image, tiles_dir, classes, tile_size=256, overlap=0)

    print("Loading Trained Model...",flush=True)
    all_tile_paths = utils.get_all_image_paths_in_folder(tiles_dir)
    total_tiles = len(all_tile_paths)

    model = unet_utils.get_small_unet(n_filters = 32,num_classes=len(classes),batch_size=batch_size)
    model.compile(
        optimizer='adam',
        loss="categorical_crossentropy",
        metrics=[unet_utils.tversky_loss,unet_utils.dice_coef,'accuracy'],
    )
    model.load_weights(constants.trained_model)

    data_generator = unet_utils.DataGeneratorWithFileNames(tiles_dir,classes,batch_size=batch_size)
    num_batches = int(np.ceil(float(total_tiles) / float(batch_size)))

    # Number of tiles written so far; once every tile has been saved, the
    # remaining entries of a batch are skipped.
    saved_tiles = 0

    print("Predicting Tiles...",flush=True)
    for _ in progressbar.progressbar(range(num_batches)):
        img_batch, filenames = next(data_generator)
        pred_all = model.predict(img_batch)
        for i in range(np.shape(pred_all)[0]):
            if saved_tiles == total_tiles:
                break
            predicted_mask = utils.onehot_to_rgb(pred_all[i],classes)
            img = img_batch[i] * 255

            # The mask keeps the tile's file name; the input tile is stored
            # next to it with "a.png" in place of ".png".
            mask_pred_path = os.path.join(pred_tiles_dir, os.path.basename(filenames[i]))
            input_copy_name = os.path.basename(filenames[i].replace(".png","a.png"))
            img_path = os.path.join(pred_tiles_dir, input_copy_name)

            save_array_as_image(mask_pred_path, predicted_mask)
            save_array_as_image(img_path, img)
            saved_tiles += 1

    print("Reassembling Tiles...",flush=True)
    for source_image in progressbar.progressbar(images_to_predict):
        dst_image_path = os.path.join(output_folder, os.path.basename(source_image))
        reassemble(source_image,dst_image_path,pred_tiles_dir)

    print("Cleaning up...",flush=True)
    utils.delete_folder_contents(temp_dir)
Ejemplo n.º 5
0
def run(src_dirs=constants.data_source_folders, working_dir=constants.working_dir):
    """Build the tiled training dataset (image tiles and mask tiles) from ``src_dirs``.

    A source folder is handled in "shapefile mode" when it contains
    ``shapes/shapes.shp`` (images are then read from its ``images``
    sub-folder); otherwise the folder itself holds images with labelme
    ``.json`` files next to them.  For every image a mask image is rendered
    from the polygons and both are tiled.  For source folders flagged in
    ``constants.only_use_area_within_shapefile_polygons``, every tile pair
    whose mask contains the "Background" colour is deleted.  Finally the tiles
    are split into train/validation/test sets.

    NOTE: this function reads a module-level ``classes`` object, presumably
    filled by the ``add_*_classes_to_label_dictionary`` helpers called first.
    """
    # Pass 1: register the classes of every source folder.
    for src_dir in src_dirs:
        shape_file_path = os.path.join(src_dir,"shapes/shapes.shp")
        if os.path.isfile(shape_file_path):
            add_shapefile_classes_to_label_dictionary(shape_file_path)
        else:
            add_labelme_classes_to_label_dictionary(src_dir)

    print(str(len(classes)) + " classes present in dataset:")
    print(classes)

    utils.save_obj(classes,os.path.join(working_dir,"labelmap.pkl"))
    utils.create_color_legend(classes, os.path.join(working_dir,"color_legend.png"))

    (temp_dir,mask_tiles_dir,image_tiles_dir) = make_folders(working_dir)

    # Pass 2: render masks and cut images and masks into tiles.
    for src_dir_index,src_dir in enumerate(src_dirs):
        shape_file_path = os.path.join(src_dir,"shapes/shapes.shp")
        shape_file_mode = os.path.isfile(shape_file_path)
        images_folder = os.path.join(src_dir,"images") if shape_file_mode else src_dir
        print("Generating masks for all images in input folder: " + src_dir,flush=True)

        for image_path in progressbar.progressbar(utils.get_all_image_paths_in_folder(images_folder)):

            if shape_file_mode:
                # Work on a reprojected copy in temp_dir, tagged with the source-folder index.
                projected_image_path = os.path.join(temp_dir,os.path.basename(image_path).replace(".tif","_srcdir" + str(src_dir_index) + ".tif"))
                utils.resize_image_and_change_coordinate_system(image_path,projected_image_path)
                image_path = projected_image_path
                all_polygons = get_all_polygons_from_shapefile(shape_file_path)
                all_polygons = convert_polygon_coords_to_pixel_coords(all_polygons,image_path)
            else:
                # labelme mode: the annotation file shares the image's name with a .json extension.
                resized_image_path = os.path.join(temp_dir, os.path.basename(image_path)[:-4]+"_srcdir" + str(src_dir_index) + ".tif")
                labelme_file_path = image_path[:-4] + ".json"
                all_polygons = get_all_polygons_from_labelme_file(labelme_file_path)
                resize_image_and_polygons(src_dir,image_path,all_polygons,resized_image_path)
                image_path = resized_image_path

            mask_image_path = os.path.join(temp_dir,os.path.basename(image_path)[:-4]+"_mask.tif")
            make_mask_image(image_path,mask_image_path,all_polygons)

            utils.tile_image(mask_image_path,mask_tiles_dir,classes, src_dir_index=src_dir_index, is_mask= True)
            utils.tile_image(image_path,image_tiles_dir,classes, src_dir_index=src_dir_index)

        if constants.only_use_area_within_shapefile_polygons[src_dir_index]:
            # NOTE(review): the pairing relies on both folder listings having
            # the same length and order, and the substring test below also
            # matches e.g. "srcdir10" when the index is 1 -- confirm against
            # the tile naming scheme.
            src_tag = "srcdir" + str(src_dir_index)
            tile_pairs = zip(utils.get_all_image_paths_in_folder(mask_tiles_dir),
                             utils.get_all_image_paths_in_folder(image_tiles_dir))
            for mask_tile,image_tile in tile_pairs:
                if src_tag not in mask_tile:
                    continue
                # Close the image before deleting its file: an open handle
                # leaks and makes os.remove fail on Windows.
                with Image.open(mask_tile) as image:
                    colors = image.getcolors(256*256)
                # getcolors yields (count, colour) entries; one match is enough.
                contains_background = any(
                    utils.name2color(classes,"Background") == color[1]
                    for color in colors)
                if contains_background:
                    os.remove(mask_tile)
                    os.remove(image_tile)

    # NOTE(review): temp_dir is not cleaned up here; the cleanup calls were
    # commented out in the original code.
    print("Splitting all training tiles and masks into train, validation and test sets...",flush=True)
    split_dataset.split_into_train_val_and_test_sets(working_dir)
Ejemplo n.º 6
0
def main(eval_args):
    """Restore a trained AutoEncoder from a checkpoint, then evaluate it or sample from it.

    If ``eval_args.eval_mode == 'evaluate'`` the model is scored with ``test``
    on the validation loader, or on the training / test loader when
    ``eval_args.eval_on_train`` / ``eval_args.eval_on_test`` is set (the test
    loader wins when both are set).  For any other mode,
    ``eval_args.repetition`` batches of 16 samples are drawn at temperature
    ``eval_args.temp``; each batch is pickled as a NumPy array and its tiled
    image is saved as a figure, both under ``eval_args.save``.
    """
    logging = utils.Logger(eval_args.local_rank, eval_args.save)

    # load a checkpoint
    logging.info('loading the model at:')
    logging.info(eval_args.checkpoint)
    checkpoint = torch.load(eval_args.checkpoint, map_location='cpu')
    args = checkpoint['args']

    logging.info('loaded the model at epoch %d', checkpoint['epoch'])
    arch_instance = utils.get_arch_cells(args.arch_instance)
    model = AutoEncoder(args, None, arch_instance)
    model.load_state_dict(checkpoint['state_dict'])
    model = model.cuda()

    logging.info('args = %s', args)
    logging.info('num conv layers: %d', len(model.all_conv_layers))
    logging.info('param size = %fM ', utils.count_parameters_in_M(model))

    if eval_args.eval_mode == 'evaluate':
        # load train valid queue
        args.data = eval_args.data
        train_queue, valid_queue, num_classes, test_queue = datasets.get_loaders(args)

        if eval_args.eval_on_train:
            logging.info('Using the training data for eval.')
            valid_queue = train_queue
        if eval_args.eval_on_test:
            logging.info('Using the test data for eval.')
            valid_queue = test_queue

        # get number of bits
        num_output = utils.num_output(args.dataset, args)
        bpd_coeff = 1. / np.log(2.) / num_output

        valid_neg_log_p, valid_nelbo = test(valid_queue, model, num_samples=eval_args.num_iw_samples, args=args, logging=logging)
        logging.info('final valid nelbo %f', valid_nelbo)
        logging.info('final valid neg log p %f', valid_neg_log_p)
        logging.info('final valid nelbo in bpd %f', valid_nelbo * bpd_coeff)
        logging.info('final valid neg log p in bpd %f', valid_neg_log_p * bpd_coeff)

    else:
        bn_eval_mode = not eval_args.readjust_bn
        num_samples = 16
        with torch.no_grad():
            n = int(np.floor(np.sqrt(num_samples)))
            set_bn(model, bn_eval_mode, num_samples=36, t=eval_args.temp, iter=500)
            for ind in range(eval_args.repetition):     # sampling is repeated.
                # Synchronize around the timed region so the measurement covers the GPU work.
                torch.cuda.synchronize()
                start = time()
                with autocast():
                    logits = model.sample(num_samples, eval_args.temp)
                output = model.decoder_output(logits)
                output_img = output.mean if isinstance(output, torch.distributions.bernoulli.Bernoulli) \
                    else output.sample()
                torch.cuda.synchronize()
                end = time()

                # save to file
                total_name = "{}/data_to_save_{}_{}.pickle".format(eval_args.save, eval_args.name_to_save, ind)
                with open(total_name, 'wb') as handle:
                    # The samples come from a model on the GPU, so they are
                    # copied to host memory before the NumPy conversion.
                    pickle.dump(output_img.detach().cpu().numpy(), handle, protocol=pickle.HIGHEST_PROTOCOL)

                output_tiled = utils.tile_image(output_img, n).cpu().numpy().transpose(1, 2, 0)
                logging.info('sampling time per batch: %0.3f sec', (end - start))
                output_tiled = np.asarray(output_tiled * 255, dtype=np.uint8)
                output_tiled = np.squeeze(output_tiled)

                plt.imshow(output_tiled)
                plt.savefig("{}/generation_{}_{}".format(eval_args.save, eval_args.name_to_save, ind))