def main():
    global args
    args = parser.parse_args()
    args.batch_size = 1  # segment a single image at a time for this experiment

    model_dir = os.path.dirname(args.dir)
    core_config_path = os.path.join(model_dir, 'configs/core.config')
    unet_config_path = os.path.join(model_dir, 'configs/unet.config')

    # loading core configuration
    core_config = CoreConfig()
    core_config.read(core_config_path)
    print('Using core configuration from {}'.format(core_config_path))

    # loading Unet configuration
    unet_config = UnetConfig()
    unet_config.read(unet_config_path, args.train_image_size)
    print('Using unet configuration from {}'.format(unet_config_path))

    offset_list = core_config.offsets
    print("offsets are: {}".format(offset_list))

    # model configurations from core config
    num_classes = core_config.num_classes
    num_colors = core_config.num_colors
    num_offsets = len(core_config.offsets)
    # model configurations from unet config
    start_filters = unet_config.start_filters
    up_mode = unet_config.up_mode
    merge_mode = unet_config.merge_mode
    depth = unet_config.depth

    model = UNet(num_classes, num_offsets,
                 in_channels=num_colors, depth=depth,
                 start_filts=start_filters,
                 up_mode=up_mode, merge_mode=merge_mode)

    model_path = os.path.join(model_dir, args.model)
    if os.path.isfile(model_path):
        print("=> loading checkpoint '{}'".format(model_path))
        # map_location keeps the load on CPU even for GPU-trained checkpoints
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("loaded.")
    else:
        print("=> no checkpoint found at '{}'".format(model_path))
    model.eval()  # inference mode: disables dropout, freezes batchnorm stats

    testset = WaldoTestset(args.test_data, args.train_image_size,
                           job=args.job, num_jobs=args.num_jobs)
    print('Total samples in the test set: {0}'.format(len(testset)))

    dataloader = torch.utils.data.DataLoader(
        testset, num_workers=1, batch_size=args.batch_size)

    segment_dir = args.dir
    if not os.path.exists(segment_dir):
        os.makedirs(segment_dir)
    segment(dataloader, segment_dir, model, core_config)
    make_submission(segment_dir, args.csv)
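# make_submission() above serializes the predicted masks into a CSV file.
# Below is a minimal sketch of run-length encoding, one common per-object
# mask format for such submission CSVs; the helper name rle_encode and the
# assumption that the format wants column-major, 1-indexed runs are
# illustrative, not taken from this codebase.
import numpy as np

def rle_encode(mask):
    # mask: 2-D binary array (1 = object pixel, 0 = background).
    # Flatten in column-major (Fortran) order, pad with zero sentinels,
    # then emit alternating (start, length) pairs for each foreground run.
    pixels = np.concatenate([[0], mask.flatten(order='F'), [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1  # 1-indexed positions
    runs[1::2] -= runs[::2]  # convert run ends into run lengths
    return ' '.join(str(x) for x in runs)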
def main():
    global args, best_loss
    args = parser.parse_args()
    if args.tensorboard:
        from tensorboard_logger import configure
        print("Using tensorboard")
        configure("%s" % (args.dir))

    # loading core configuration
    c_config = CoreConfig()
    if args.core_config == '':
        print('No core config file given, using default core configuration')
    else:
        if not os.path.exists(args.core_config):
            sys.exit('Cannot find the config file: {}'.format(args.core_config))
        c_config.read(args.core_config)
        print('Using core configuration from {}'.format(args.core_config))

    # loading Unet configuration
    u_config = UnetConfig()
    if args.unet_config == '':
        print('No unet config file given, using default unet configuration')
    else:
        if not os.path.exists(args.unet_config):
            sys.exit('Cannot find the unet configuration file: {}'.format(
                args.unet_config))
        # need train_image_size for validation
        u_config.read(args.unet_config, args.train_image_size)
        print('Using unet configuration from {}'.format(args.unet_config))

    offset_list = c_config.offsets
    print("offsets are: {}".format(offset_list))

    # model configurations from core config
    num_classes = c_config.num_classes
    num_colors = c_config.num_colors
    num_offsets = len(c_config.offsets)
    # model configurations from unet config
    start_filters = u_config.start_filters
    up_mode = u_config.up_mode
    merge_mode = u_config.merge_mode
    depth = u_config.depth

    train_data = args.train_dir + '/train'
    val_data = args.train_dir + '/val'

    trainset = WaldoDataset(train_data, c_config, args.train_image_size)
    trainloader = torch.utils.data.DataLoader(
        trainset, num_workers=4, batch_size=args.batch_size, shuffle=True)

    valset = WaldoDataset(val_data, c_config, args.train_image_size)
    valloader = torch.utils.data.DataLoader(
        valset, num_workers=4, batch_size=args.batch_size)

    NUM_TRAIN = len(trainset)
    NUM_VAL = len(valset)
    NUM_ALL = NUM_TRAIN + NUM_VAL
    print('Total samples: {0} \n'
          'Using {1} samples for training, '
          '{2} samples for validation'.format(NUM_ALL, NUM_TRAIN, NUM_VAL))

    # create model
    model = UNet(num_classes, num_offsets,
                 in_channels=num_colors, depth=depth,
                 start_filts=start_filters,
                 up_mode=up_mode, merge_mode=merge_mode).cuda()

    # get the number of model parameters
    print('Number of model parameters: {}'.format(
        sum([p.data.nelement() for p in model.parameters()])))

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_loss = checkpoint['best_loss']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    # define optimizer
    # optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                nesterov=args.nesterov,
                                weight_decay=args.weight_decay)

    # train and validate for the configured number of epochs
    for epoch in range(args.start_epoch, args.epochs):
        Train(trainloader, model, optimizer, epoch)
        val_loss = Validate(valloader, model, epoch)
        is_best = val_loss < best_loss
        best_loss = min(val_loss, best_loss)
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            # key must be 'best_loss' to match what the resume path reads
            'best_loss': best_loss,
        }, is_best)
    print('Best validation loss: ', best_loss)

    # visualize some example outputs
    outdir = '{}/imgs'.format(args.dir)
    if not os.path.exists(outdir):
        os.makedirs(outdir)
    sample(model, valloader, outdir, c_config)
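# save_checkpoint() is not defined in this snippet. Below is a minimal
# sketch of the usual PyTorch pattern, assuming checkpoints live under
# args.dir and the filenames 'checkpoint.pth.tar' / 'model_best.pth.tar'
# (all of these are assumptions, not taken from this codebase).
import os
import shutil
import torch

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    # Always write the most recent state; additionally keep a copy of the
    # best-scoring model so far, which the resume path above can reload.
    path = os.path.join(args.dir, filename)
    torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(args.dir, 'model_best.pth.tar'))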
def main():
    global args
    args = parser.parse_args()
    args.batch_size = 1  # segment a single image at a time for this experiment

    core_config_path = os.path.join(args.dir, 'configs/core.config')
    unet_config_path = os.path.join(args.dir, 'configs/unet.config')

    # loading core configuration
    core_config = CoreConfig()
    core_config.read(core_config_path)
    print('Using core configuration from {}'.format(core_config_path))

    # loading Unet configuration
    unet_config = UnetConfig()
    unet_config.read(unet_config_path, args.train_image_size)
    print('Using unet configuration from {}'.format(unet_config_path))

    offset_list = core_config.offsets
    print("offsets are: {}".format(offset_list))

    # model configurations from core config
    num_classes = core_config.num_classes
    num_colors = core_config.num_colors
    num_offsets = len(core_config.offsets)
    # model configurations from unet config
    start_filters = unet_config.start_filters
    up_mode = unet_config.up_mode
    merge_mode = unet_config.merge_mode
    depth = unet_config.depth

    model = UNet(num_classes, num_offsets,
                 in_channels=num_colors, depth=depth,
                 start_filts=start_filters,
                 up_mode=up_mode, merge_mode=merge_mode)

    model_path = os.path.join(args.dir, args.model)
    if os.path.isfile(model_path):
        print("=> loading checkpoint '{}'".format(model_path))
        # map_location keeps the load on CPU even for GPU-trained checkpoints
        checkpoint = torch.load(model_path,
                                map_location=lambda storage, loc: storage)
        model.load_state_dict(checkpoint['state_dict'])
        print("loaded.")
    else:
        print("=> no checkpoint found at '{}'".format(model_path))
    model.eval()  # inference mode: disables dropout, freezes batchnorm stats

    testset = WaldoDataset(args.test_data, core_config, args.train_image_size)
    print('Total samples in the test set: {0}'.format(len(testset)))

    dataloader = torch.utils.data.DataLoader(
        testset, num_workers=1, batch_size=args.batch_size)

    segment_dir = '{}/segment'.format(args.dir)
    if not os.path.exists(segment_dir):
        os.makedirs(segment_dir)

    img, class_pred, adj_pred = sample(model, dataloader,
                                       segment_dir, core_config)
    seg = ObjectSegmenter(class_pred[0].detach().numpy(),
                          adj_pred[0].detach().numpy(),
                          num_classes, offset_list)
    mask_pred, object_class = seg.run_segmentation()

    x = {}
    # from (color, height, width) to (height, width, color)
    x['img'] = np.moveaxis(img[0].numpy(), 0, -1)
    x['mask'] = mask_pred.astype(int)
    x['object_class'] = object_class
    visualize_mask(x, core_config)
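# For reference, the (color, height, width) -> (height, width, color)
# conversion done with np.moveaxis above, shown on a toy array:
import numpy as np

img_chw = np.zeros((3, 128, 128))      # PyTorch layout: channels first
img_hwc = np.moveaxis(img_chw, 0, -1)  # image-tool layout: channels last
assert img_hwc.shape == (128, 128, 3)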