def main(args):
    dataset_configs = DatasetParams(args.dataset_config_file)
    dataset_params_1 = dataset_configs.get_params(args.dataset_1)
    dataset_params_2 = dataset_configs.get_params(args.dataset_2)

    for p in (dataset_params_1, dataset_params_2):
        for d in p:
            # Tell the dataset to output the id in integer or other simple format:
            d.config_dict['return_simple_image_id'] = True

    data_loader_1, _ = get_loader(dataset_params_1, vocab=None, transform=None,
                                  batch_size=args.batch_size, shuffle=False,
                                  num_workers=args.num_workers,
                                  ext_feature_sets=None,
                                  skip_images=True, iter_over_images=True)

    data_loader_2, _ = get_loader(dataset_params_2, vocab=None, transform=None,
                                  batch_size=args.batch_size, shuffle=False,
                                  num_workers=args.num_workers,
                                  ext_feature_sets=None,
                                  skip_images=True, iter_over_images=True)

    show_progress = sys.stderr.isatty()

    print("Reading image ids from dataset {}".format(args.dataset_1))
    ids_1 = [img_ids for _, _, _, img_ids, _ in
             tqdm(data_loader_1, disable=not show_progress)]
    set_1 = set(chain(*ids_1))

    print("Reading image ids from dataset {}".format(args.dataset_2))
    ids_2 = [img_ids for _, _, _, img_ids, _ in
             tqdm(data_loader_2, disable=not show_progress)]
    set_2 = set(chain(*ids_2))

    intersection = set_1.intersection(set_2)
    len_intersect = len(intersection)

    print("There are {} images shared between {} and {}".format(
        len_intersect, args.dataset_1, args.dataset_2))
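
# A minimal sketch of how the overlap check above could be invoked from the
# command line. The argument names mirror the attributes main() reads; the
# defaults are illustrative assumptions, not values taken from this repository.
def _example_cli():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('dataset_1', type=str, help='first dataset to compare')
    parser.add_argument('dataset_2', type=str, help='second dataset to compare')
    parser.add_argument('--dataset_config_file', type=str,
                        default='datasets/datasets.conf')
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--num_workers', type=int, default=2)
    main(parser.parse_args())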
def get_feature_dims(state, args):
    dataset_configs = DatasetParams(args.dataset_config_file)
    dataset_params = dataset_configs.get_params(args.dataset)[0]
    features_paths = dataset_params.features_path

    ext_feature_sets = [state['features'].external,
                        state['persist_features'].external]

    loaders_and_dims = [ExternalFeature.loaders(fs, features_paths)
                        for fs in ext_feature_sets]

    _loaders, dims = zip(*loaders_and_dims)
    return dims
def check_dataset(args):
    if args.dataset is None:
        print('ERROR: No dataset selected!')
        print('Please supply a training dataset with the argument '
              '--dataset DATASET')
        print('The following datasets are configured in {}:'.format(
            args.dataset_config_file))
        dataset_configs = DatasetParams(args.dataset_config_file)
        for ds, _ in dataset_configs.config.items():
            if ds not in ('DEFAULT', 'generic'):
                print(' ', ds)
        sys.exit(1)

    return args.dataset
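
# For reference, DatasetParams reads an INI-style configuration file with one
# section per dataset, plus the 'DEFAULT' and 'generic' sections skipped in the
# listing above. The snippet below is a hypothetical example of such a section;
# the option names and paths are assumptions for illustration only and are not
# taken from this repository (see datasets/datasets.conf.default for the real
# reference):
#
#   [coco:train2014]
#   root_dir      = /data/coco/train2014
#   caption_path  = /data/coco/annotations/captions_train2014.json
#   features_path = /data/coco/features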
def infer(self, args):
    if 'image_features' not in args:
        args['image_features'] = None

    # Image preprocessing:
    transform = transforms.Compose([
        transforms.Resize((args['resize'], args['resize'])),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])

    # Get dataset parameters:
    dataset_configs = DatasetParams(args['dataset_config_file'])
    dataset_params = dataset_configs.get_params(args['dataset'],
                                                args['image_dir'],
                                                args['image_files'],
                                                args['image_features'])

    if self.params.has_external_features() and \
       any(dc.name == 'generic' for dc in dataset_params):
        print('WARNING: you cannot use external features without specifying '
              'all datasets in datasets.conf.')
        print('Hint: take a look at datasets/datasets.conf.default.')

    # Build the data loader:
    print("Loading dataset: {}".format(args['dataset']))

    # Update dataset params with the needed model params:
    for i in dataset_params:
        i.config_dict['skip_start_token'] = self.params.skip_start_token
        # For visualizing attention we need file names instead of ids
        # in our output:
        if args['store_image_paths']:
            i.config_dict['return_image_file_name'] = True

    ext_feature_sets = [self.params.features.external,
                        self.params.persist_features.external]
    if args['dataset'] == 'incore':
        ext_feature_sets = None

    # We ask the loader to iterate over images instead of all
    # (image, caption) pairs:
    data_loader, ef_dims = get_loader(
        dataset_params, vocab=None, transform=transform,
        batch_size=args['batch_size'], shuffle=False,
        num_workers=args['num_workers'], ext_feature_sets=ext_feature_sets,
        skip_images=not self.params.has_internal_features(),
        iter_over_images=True)
    self.data_loader = data_loader

    # Create the results directory:
    if not os.path.exists(args['results_path']):
        os.makedirs(args['results_path'])

    scorers = {}
    if args['scoring'] is not None:
        for s in args['scoring'].split(','):
            s = s.lower().strip()
            if s == 'cider':
                from eval.cider import Cider
                scorers['CIDEr'] = Cider(df='corpus')

    # Store the generated captions here:
    output_data = []

    gts = {}
    res = {}

    print('Starting inference, max sentence length: {} num_workers: {}'.format(
        args['max_seq_length'], args['num_workers']))
    show_progress = sys.stderr.isatty() and not args['verbose'] \
        and ext_feature_sets is not None

    for images, ref_captions, lengths, image_ids, features in \
            tqdm(self.data_loader, disable=not show_progress):
        if len(scorers) > 0:
            # Collect the ground-truth reference captions:
            for j in range(len(ref_captions)):
                jid = image_ids[j]
                if jid not in gts:
                    gts[jid] = []
                rcs = ref_captions[j]
                if isinstance(rcs, str):
                    rcs = [rcs]
                for rc in rcs:
                    gts[jid].append(rc.lower())

        images = images.to(device)

        init_features = features[0].to(device) if len(features) > 0 and \
            features[0] is not None else None
        persist_features = features[1].to(device) if len(features) > 1 and \
            features[1] is not None else None

        # Generate a caption from the image:
        sampled_batch = self.model.sample(
            images, init_features, persist_features,
            max_seq_length=args['max_seq_length'],
            start_token_id=self.vocab('<start>'),
            end_token_id=self.vocab('<end>'),
            alternatives=args['alternatives'],
            probabilities=args['probabilities'])

        sampled_ids_batch = sampled_batch

        for i in range(len(sampled_ids_batch)):
            sampled_ids = sampled_ids_batch[i]

            # Convert word ids to words:
            if self.params.hierarchical_model:
                caption = paragraph_ids_to_words(sampled_ids, self.vocab,
                                                 skip_start_token=True)
            else:
                caption = caption_ids_ext_to_words(
                    sampled_ids, self.vocab, skip_start_token=True,
                    capitalize=not args['no_capitalize'])

            if args['no_repeat_sentences']:
                caption = remove_duplicate_sentences(caption)

            if args['only_complete_sentences']:
                caption = remove_incomplete_sentences(caption)

            if args['verbose']:
                print('=>', caption)

            caption = self.apply_lemma_pos_rules(caption)
            if args['verbose']:
                print('#>', caption)

            output_data.append({'image_id': image_ids[i], 'caption': caption})
            res[image_ids[i]] = [caption.lower()]

    for score_name, scorer in scorers.items():
        score = scorer.compute_score(gts, res)[0]
        print('Test', score_name, score)

    # Decide the output format, falling back to txt:
    if args['output_format'] is not None:
        output_format = args['output_format']
    elif args['output_file'] and args['output_file'].endswith('.json'):
        output_format = 'json'
    else:
        output_format = 'txt'

    # Create a sensible default output path for the results:
    output_file = None
    if not args['output_file'] and not args['print_results']:
        model_name_path = Path(args['model'])
        is_in_same_folder = len(model_name_path.parents) == 1
        if not is_in_same_folder:
            model_name = args['model'].split(os.sep)[-2]
            model_epoch = basename(args['model'])
            output_file = '{}-{}.{}'.format(model_name, model_epoch,
                                            output_format)
        else:
            output_file = model_name_path.stem + '.' + output_format
    else:
        output_file = args['output_file']

    if output_file:
        output_path = os.path.join(args['results_path'], output_file)
        if output_format == 'json':
            with open(output_path, 'w') as fp:
                json.dump(output_data, fp)
        else:
            with open(output_path, 'w') as fp:
                for data in output_data:
                    print(data['image_id'], data['caption'], file=fp)
        # Report the resolved output format, which may have been inferred
        # above rather than given in args:
        print('Wrote generated captions to {} as {}'.format(output_path,
                                                            output_format))

    if args['print_results']:
        for d in output_data:
            print('{}: {}'.format(d['image_id'], d['caption']))

    return output_data
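
# A minimal sketch of calling infer() above. infer() indexes args like a
# dictionary, so a plain dict works; the keys below are exactly the ones the
# method reads, while every value (paths, sizes, the checkpoint name) is an
# illustrative assumption.
example_args = {
    'dataset_config_file': 'datasets/datasets.conf',
    'dataset': 'generic',
    'image_dir': None,
    'image_files': ['example.jpg'],          # hypothetical input image
    'resize': 224,
    'batch_size': 1,
    'num_workers': 0,
    'results_path': 'results',
    'model': 'models/my_model/ep10.model',   # hypothetical checkpoint path
    'max_seq_length': 20,
    'scoring': None,
    'store_image_paths': False,
    'verbose': True,
    'alternatives': 1,
    'probabilities': False,
    'no_capitalize': False,
    'no_repeat_sentences': False,
    'only_complete_sentences': False,
    'output_format': None,
    'output_file': None,
    'print_results': True,
}
# captions = inferencer.infer(example_args)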
def main(args):
    # Image preprocessing
    if args.feature_type == 'plain':
        if args.extractor == 'resnet152caffe-original':
            # Use a custom transform:
            transform = transforms.Compose([
                transforms.Resize((args.crop_size, args.crop_size)),
                # Swap color space from RGB to BGR and subtract caffe-specific
                # channel values from each pixel:
                transforms.Lambda(
                    lambda img: np.array(img, dtype=np.float32)[..., [2, 1, 0]]
                    - [103.939, 116.779, 123.68]),
                # Create a torch tensor and put channels first:
                transforms.Lambda(
                    lambda img: torch.from_numpy(img).permute(2, 0, 1)),
                # Cast the tensor to the correct type:
                transforms.Lambda(lambda img: img.type('torch.FloatTensor'))
            ])
        else:
            # Default transform:
            transform = transforms.Compose([
                transforms.Resize((args.crop_size, args.crop_size)),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))
            ])
    elif args.feature_type in ('avg', 'max'):
        # See the transform examples at:
        # https://pytorch.org/docs/stable/torchvision/transforms.html
        if args.normalize == 'default':
            transform = transforms.Compose([
                transforms.Resize((args.image_size, args.image_size)),
                # 10-crop implementation as described in the PyTorch docs:
                transforms.TenCrop((args.crop_size, args.crop_size)),
                # Apply the next two transforms to each crop in turn and then
                # stack the crops into a single tensor:
                transforms.Lambda(lambda crops: torch.stack([
                    transforms.Normalize((0.485, 0.456, 0.406),
                                         (0.229, 0.224, 0.225))
                    (transforms.ToTensor()(crop)) for crop in crops
                ]))
            ])
        elif args.normalize == 'skip':
            # Skip normalization altogether:
            transform = transforms.Compose([
                transforms.Resize((args.image_size, args.image_size)),
                transforms.TenCrop((args.crop_size, args.crop_size)),
                transforms.Lambda(lambda crops: torch.stack(
                    [transforms.ToTensor()(crop) for crop in crops]))
            ])
        elif args.normalize == 'subtract_half':
            # Subtract 0.5 from all values instead of normalizing:
            transform = transforms.Compose([
                transforms.Resize((args.image_size, args.image_size)),
                transforms.TenCrop((args.crop_size, args.crop_size)),
                transforms.Lambda(lambda crops: torch.stack(
                    [transforms.ToTensor()(crop) for crop in crops]) - 0.5)
            ])
        else:
            print("Invalid normalization parameter: {}".format(args.normalize))
            sys.exit(1)
    else:
        print("Invalid feature type specified: {}".format(args.feature_type))
        sys.exit(1)

    print("Creating features of type: {}".format(args.feature_type))

    # Get dataset parameters and vocabulary wrapper:
    dataset_configs = DatasetParams(args.dataset_config_file)
    dataset_params = dataset_configs.get_params(args.dataset)

    # We only want the image file name, not the full path:
    for i in dataset_params:
        i.config_dict['return_image_file_name'] = True

    # We ask the loader to iterate over images instead of all
    # (image, caption) pairs:
    data_loader, _ = get_loader(dataset_params, vocab=None,
                                transform=transform,
                                batch_size=args.batch_size, shuffle=False,
                                num_workers=args.num_workers,
                                ext_feature_sets=None,
                                skip_images=False, iter_over_images=True)

    extractor = FeatureExtractor(args.extractor, True).to(device).eval()

    # To open an LMDB handle and prepare it for the right size, it needs to
    # fit the total number of elements in the dataset, so we set map_size
    # to a largish value here:
    map_size = 1e12
    lmdb_path = None
    file_name = None

    if args.output_file:
        file_name = args.output_file
    else:
        file_name = '{}-{}-{}-normalize-{}.lmdb'.format(
            args.dataset, args.extractor, args.feature_type, args.normalize)

    os.makedirs(args.output_dir, exist_ok=True)
    lmdb_path = os.path.join(args.output_dir, file_name)

    # Check that we are not overwriting anything:
    if os.path.exists(lmdb_path):
        print('ERROR: {} exists, please remove it first if you really want '
              'to replace it.'.format(lmdb_path))
        sys.exit(1)

    print("Preparing to store extracted features to {}...".format(lmdb_path))
    print("Starting to extract features from dataset {} using {}...".format(
        args.dataset, args.extractor))

    show_progress = sys.stderr.isatty()

    # If the feature shape is not 1-dimensional, store feature shape metadata:
    if isinstance(extractor.output_dim, np.ndarray):
        with lmdb.open(lmdb_path, map_size=map_size) as env:
            with env.begin(write=True) as txn:
                txn.put(str('@vdim').encode('ascii'), extractor.output_dim)

    for i, (images, _, _, image_ids,
            _) in enumerate(tqdm(data_loader, disable=not show_progress)):
        images = images.to(device)

        # If we are dealing with cropped images, the image dimensions are
        # (bs, ncrops, c, h, w):
        if images.dim() == 5:
            bs, ncrops, c, h, w = images.size()
            # Fuse batch size and ncrops:
            raw_features = extractor(images.view(-1, c, h, w))
            if args.feature_type == 'avg':
                # Average over crops:
                features = raw_features.view(
                    bs, ncrops, -1).mean(1).data.cpu().numpy()
            elif args.feature_type == 'max':
                # Max over crops:
                features = raw_features.view(
                    bs, ncrops, -1).max(1)[0].data.cpu().numpy()
        # Otherwise our image dimensions are (bs, c, h, w):
        else:
            features = extractor(images).data.cpu().numpy()

        # Write to the LMDB object:
        with lmdb.open(lmdb_path, map_size=map_size) as env:
            with env.begin(write=True) as txn:
                for j, image_id in enumerate(image_ids):
                    # If the output dimension is not a scalar, flatten the
                    # array. When retrieving this feature from the LMDB, the
                    # developer must take care to reshape the feature back
                    # to the correct dimensions!
                    if isinstance(extractor.output_dim, np.ndarray):
                        _feature = features[j].flatten()
                    # Otherwise store it as is:
                    else:
                        _feature = features[j]
                    txn.put(str(image_id).encode('ascii'), _feature)

        # Print log info:
        if not show_progress and ((i + 1) % args.log_step == 0):
            print('Batch [{}/{}]'.format(i + 1, len(data_loader)))
            sys.stdout.flush()
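
# A minimal sketch of reading one feature back from the LMDB written above,
# including the reshape step the comment in main() warns about. The key layout
# ('@vdim' for the stored shape, ASCII image ids for the vectors) matches
# main(); the float32/int64 dtypes are assumptions about how the arrays were
# serialized, not guarantees from this repository.
def read_feature_sketch(lmdb_path, image_id):
    with lmdb.open(lmdb_path, readonly=True, lock=False) as env:
        with env.begin() as txn:
            raw = txn.get(str(image_id).encode('ascii'))
            feature = np.frombuffer(raw, dtype=np.float32)  # dtype assumed
            vdim = txn.get('@vdim'.encode('ascii'))
            if vdim is not None:
                # Restore the original (non-flat) feature shape:
                shape = np.frombuffer(vdim, dtype=np.int64)  # dtype assumed
                feature = feature.reshape(shape)
    return feature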
def main(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = \
        ','.join(str(gpu) for gpu in args.visible_gpus)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    dataset = args.dataset
    net_name = args.cnn_name
    root_path = '.'
    data_path = os.path.join(root_path, "data")
    save_path = os.path.join(root_path, "results",
                             "iDLG_%s_%s" % (dataset, net_name))
    if args.add_clamp:
        save_path += "_clamp"

    # Some run settings:
    initial_lr = args.lr
    num_dummy = 1
    Iteration = args.max_iter
    plot_steps = args.plot_steps

    # run_methods = ["iDLG", "DLG"]
    run_methods = ["iDLG"]

    tt = transforms.Compose([transforms.ToTensor()])
    tp = transforms.Compose([transforms.ToPILImage()])

    print(dataset, 'root_path:', root_path)
    print(dataset, 'data_path:', data_path)
    print(dataset, 'save_path:', save_path)

    if not os.path.exists('results'):
        os.mkdir('results')
    if not os.path.exists(save_path):
        os.mkdir(save_path)

    # Load the data:
    data_params = DatasetParams()
    data_params.config(name=dataset, root_path=root_path, data_path=data_path)

    shape_img = data_params.shape_img
    num_classes = data_params.num_classes
    channel = data_params.channel
    dst = data_params.dst
    selected_indices = data_params.selected_indices
    cmap = data_params.cmap

    # Build the ConvNet:
    net = config_net(net_name=net_name,
                     input_shape=(channel, ) + shape_img,
                     num_classes=num_classes)
    net = net.to(device)

    # Load pretrained model weights:
    if os.path.isfile(args.model_ckpt):
        ckpt = torch.load(args.model_ckpt)
        net.load_state_dict(ckpt)

    num_success = 0
    num_exp = len(selected_indices)
    criterion = nn.CrossEntropyLoss().to(device)

    # Train DLG and iDLG:
    for idx_exp in range(num_exp):
        print('running %d|%d experiment' % (idx_exp, num_exp))
        np.random.seed(idx_exp)

        for method in run_methods:
            print('%s, Try to generate %d images' % (method, num_dummy))

            imidx_list = []

            # Get the ground-truth image and label:
            idx = selected_indices[idx_exp]
            imidx_list.append(idx)
            tmp_datum = tt(dst[idx][0]).float().to(device)
            tmp_datum = tmp_datum.view(1, *tmp_datum.size())
            tmp_label = torch.Tensor([dst[idx][1]]).long().to(device)
            tmp_label = tmp_label.view(1, )
            gt_data = tmp_datum
            gt_label = tmp_label

            # Compute the original gradient:
            out = net(gt_data)
            y = criterion(out, gt_label)
            dy_dx = torch.autograd.grad(y, net.parameters())
            orig_dy_dx = list(t.detach().clone() for t in dy_dx)

            if args.grad_norm:
                # Min-max normalize each gradient tensor:
                grad_max = [x.max().item() for x in orig_dy_dx]
                grad_min = [x.min().item() for x in orig_dy_dx]
                orig_dy_dx = [
                    (g - g_min) / (g_max - g_min)
                    for g, g_min, g_max in zip(orig_dy_dx, grad_min, grad_max)
                ]

            # Generate the dummy data and label:
            torch.manual_seed(10)
            dummy_data = torch.randn(
                gt_data.size()).to(device).requires_grad_(True)
            dummy_label = torch.randn(
                (gt_data.shape[0], num_classes)).to(device).requires_grad_(True)

            # Truncate the dummy image and label:
            if args.add_clamp:
                dummy_data.data = torch.clamp(dummy_data + 0.5, 0, 1)
                dummy_label.data = torch.clamp(dummy_label + 0.5, 0, 1)

            if method == 'DLG':
                # DLG optimizes both the dummy image and the dummy label:
                optimizer = torch.optim.LBFGS(
                    [{'params': [dummy_data, dummy_label],
                      'initial_lr': initial_lr}],
                    lr=initial_lr, max_iter=50, tolerance_grad=1e-9,
                    tolerance_change=1e-11, history_size=250,
                    line_search_fn="strong_wolfe")
            elif method == 'iDLG':
                # iDLG optimizes only the dummy image:
                optimizer = torch.optim.LBFGS(
                    [{'params': [dummy_data],
                      'initial_lr': initial_lr}],
                    lr=initial_lr, max_iter=50, tolerance_grad=1e-9,
                    tolerance_change=1e-11, history_size=250,
                    line_search_fn="strong_wolfe")

            # Predict the ground-truth label from the gradient of the last
            # fully-connected layer's weights (the iDLG trick):
            label_pred = torch.argmin(torch.sum(orig_dy_dx[-2], dim=-1),
                                      dim=-1).detach().reshape(
                                          (1, )).requires_grad_(False)

            history = []
            history_iters = []
            losses = []
            mses = []
            train_iters = []

            scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                        step_size=30,
                                                        gamma=0.95,
                                                        last_epoch=-1)
            print('lr =', initial_lr)

            for iters in range(Iteration):
                closure = functools.partial(_closure,
                                            optimizer=optimizer,
                                            dummy_data=dummy_data,
                                            dummy_label=dummy_label,
                                            label_pred=label_pred,
                                            method=method,
                                            criterion=criterion,
                                            net=net,
                                            orig_dy_dx=orig_dy_dx,
                                            grad_norm=args.grad_norm)
                optimizer.step(closure)

                # Clamp pixel values:
                if args.add_clamp:
                    dummy_data.data = torch.clamp(dummy_data, 0, 1)

                current_loss = closure().item()
                train_iters.append(iters)
                losses.append(current_loss)
                mses.append(torch.mean((dummy_data - gt_data)**2).item())
                scheduler.step()

                if iters % plot_steps == 0:
                    current_time = get_current_time()
                    print(current_time, iters,
                          'loss = %.8f, mse = %.8f' % (current_loss, mses[-1]))
                    history.append([tp(dummy_data[imidx].cpu())
                                    for imidx in range(num_dummy)])
                    history_iters.append(iters)

                    for imidx in range(num_dummy):
                        plot_dummy_x(imidx, cmap, tp, gt_data, history,
                                     history_iters, save_path, method,
                                     selected_indices, idx_exp)

                # Stop early once the reconstruction has converged:
                if mses[-1] < 1e-4:
                    break

            if mses[-1] < 1e-3:
                num_success += 1

            # Save the MSE curve:
            plot_mse_curve(mses, iters, save_path, method, selected_indices,
                           idx_exp)

            if method == 'DLG':
                loss_DLG = losses
                label_DLG = torch.argmax(dummy_label, dim=-1).detach().item()
                mse_DLG = mses
            elif method == 'iDLG':
                loss_iDLG = losses
                label_iDLG = label_pred.item()
                mse_iDLG = mses

        print('gt_label:', gt_label.detach().cpu().data.numpy())
        if "DLG" in run_methods:
            print('loss_DLG:', loss_DLG[-1], 'mse_DLG:', mse_DLG[-1],
                  'lab_DLG:', label_DLG)
        if "iDLG" in run_methods:
            print('loss_iDLG:', loss_iDLG[-1], 'mse_iDLG:', mse_iDLG[-1],
                  'lab_iDLG:', label_iDLG)
        print('----------------------\n\n')

    print("Number of successful recoveries:", num_success)
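
# main() builds its LBFGS closure from a module-level _closure() via
# functools.partial. A minimal sketch of what such a closure typically computes
# in DLG/iDLG-style gradient matching is shown below, using the same argument
# names as the partial above; the actual _closure() in this repository may
# differ in its details.
import torch.nn.functional as F

def _closure_sketch(optimizer, dummy_data, dummy_label, label_pred, method,
                    criterion, net, orig_dy_dx, grad_norm):
    optimizer.zero_grad()
    pred = net(dummy_data)
    if method == 'DLG':
        # DLG also optimizes the label, so use a softmaxed dummy label:
        dummy_onehot = F.softmax(dummy_label, dim=-1)
        dummy_loss = torch.mean(
            torch.sum(-dummy_onehot * F.log_softmax(pred, dim=-1), dim=-1))
    else:
        # iDLG uses the analytically predicted label:
        dummy_loss = criterion(pred, label_pred)
    dummy_dy_dx = torch.autograd.grad(dummy_loss, net.parameters(),
                                      create_graph=True)
    if grad_norm:
        # Assumed to mirror the min-max normalization applied to orig_dy_dx:
        dummy_dy_dx = [(g - g.min()) / (g.max() - g.min())
                       for g in dummy_dy_dx]
    # Squared L2 distance between the dummy and the captured gradients:
    grad_diff = sum(((gx - gy)**2).sum()
                    for gx, gy in zip(dummy_dy_dx, orig_dy_dx))
    grad_diff.backward()
    return grad_diff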
def main(args):
    if args.output_file:
        file_name = args.output_file
    else:
        if args.environment is not None:
            environment = args.environment
        else:
            environment = os.getenv('HOSTNAME')
            if environment is None:
                environment = 'unknown_host'
        file_name = 'image_file_list-{}-{}.txt'.format(args.dataset,
                                                       environment)

    file_name = os.path.join(args.output_path, file_name)
    os.makedirs(args.output_path, exist_ok=True)

    # If we want to generate multiple files, we need to add an "_X_of_Y"
    # suffix to each file name to indicate which file out of the set it is:
    if args.num_files > 1:
        file_name_list = []
        for i in range(args.num_files):
            file_name_i = os.path.splitext(
                file_name)[0] + '_{}_of_{}.txt'.format(i + 1, args.num_files)
            file_name_list.append(file_name_i)
    else:
        file_name_list = None

    # Check that we are not overwriting anything:
    if os.path.exists(file_name):
        print('ERROR: {} exists, please remove it first if you really want '
              'to replace it.'.format(file_name))
        sys.exit(1)

    dataset_configs = DatasetParams(args.dataset_config_file)
    dataset_params = dataset_configs.get_params(args.dataset)

    for d in dataset_params:
        # Tell the dataset to output full image paths instead of image ids:
        d.config_dict['return_full_image_path'] = True

    # We ask the loader to iterate over images instead of all
    # (image, caption) pairs:
    data_loader, _ = get_loader(dataset_params, vocab=None, transform=None,
                                batch_size=args.batch_size, shuffle=False,
                                num_workers=args.num_workers,
                                ext_feature_sets=None,
                                skip_images=True, iter_over_images=True)

    print("Getting file paths from dataset {}...".format(args.dataset))
    show_progress = sys.stderr.isatty()

    for i, (_, _, _, paths,
            _) in enumerate(tqdm(data_loader, disable=not show_progress)):
        if args.num_files == 1:
            _file_name = file_name
        else:
            # Map the current batch to its output shard:
            n = int(i * data_loader.batch_size * args.num_files /
                    len(data_loader.dataset))
            _file_name = file_name_list[n]
        with open(_file_name, 'a') as f:
            for path in paths:
                f.write(path + '\n')

        # Print log info:
        if not show_progress and ((i + 1) % args.log_step == 0):
            print('Batch [{}/{}]'.format(i + 1, len(data_loader)))
            sys.stdout.flush()

    print("Wrote paths for {} image files".format(len(data_loader.dataset)))
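
# The shard index above maps batch i to output file n by the batch's position
# in the dataset. A small worked example with illustrative numbers (1000
# images, batch_size=100, num_files=4; none of these come from the repository):
#   batches 0..2 -> int(2 * 100 * 4 / 1000) = 0  -> ..._1_of_4.txt
#   batches 3..4 -> int(3 * 100 * 4 / 1000) = 1  -> ..._2_of_4.txt
#   batch  9     -> int(9 * 100 * 4 / 1000) = 3  -> ..._4_of_4.txt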
def main(args):
    if args.model_name is not None:
        print('Preparing to train model: {}'.format(args.model_name))

    global device
    device = torch.device(
        'cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')

    sc_will_happen = args.self_critical_from_epoch != -1

    if args.validate is None and args.lr_scheduler == 'ReduceLROnPlateau':
        print('ERROR: you need to enable validation in order to use the '
              'default lr_scheduler (ReduceLROnPlateau)')
        print('Hint: use something like --validate=coco:val2017')
        sys.exit(1)

    # Create the model directory:
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing, normalization for the pretrained resnet:
    transform = transforms.Compose([
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    scorers = {}
    if args.validation_scoring is not None or sc_will_happen:
        assert not (args.validation_scoring is None and sc_will_happen), \
            "Please provide a metric when using self-critical training"
        for s in args.validation_scoring.split(','):
            s = s.lower().strip()
            if s == 'cider':
                from eval.cider import Cider
                scorers['CIDEr'] = Cider()
            if s == 'ciderd':
                from eval.ciderD.ciderD import CiderD
                scorers['CIDEr-D'] = CiderD(df=args.cached_words)

    ########################
    # Set Model parameters #
    ########################

    # Store the parameters obtained from the command arguments separately:
    arg_params = ModelParams.fromargs(args)

    print("Model parameters inferred from command arguments: ")
    print(arg_params)

    start_epoch = 0

    ###############################
    # Load existing model state   #
    # and update Model parameters #
    ###############################

    state = None

    if args.load_model:
        try:
            state = torch.load(args.load_model, map_location=device)
        except AttributeError:
            print('WARNING: Old model found. Please use model_update.py on '
                  'the model before executing this script.')
            exit(1)
        new_external_features = arg_params.features.external

        params = ModelParams(state, arg_params=arg_params)
        if len(new_external_features) and \
           params.features.external != new_external_features:
            print('WARNING: external features changed: ',
                  params.features.external, new_external_features)
            print('Updating feature paths...')
            params.update_ext_features(new_external_features)

        start_epoch = state['epoch']
        print('Loaded model {} at epoch {}'.format(args.load_model,
                                                   start_epoch))
    else:
        params = arg_params
        params.command_history = []

    if params.rnn_hidden_init == 'from_features' and params.skip_start_token:
        print("ERROR: Please remove --skip_start_token if you want to use "
              "image features to initialize the hidden and cell states. The "
              "<start> token is needed to trigger the sequence generation "
              "process, since we don't have an image feature embedding as "
              "the first input token.")
        sys.exit(1)

    # Force-set the following hierarchical model parameters every time:
    if arg_params.hierarchical_model:
        params.hierarchical_model = True
        params.max_sentences = arg_params.max_sentences
        params.weight_sentence_loss = arg_params.weight_sentence_loss
        params.weight_word_loss = arg_params.weight_word_loss
        params.dropout_stopping = arg_params.dropout_stopping
        params.dropout_fc = arg_params.dropout_fc
        params.coherent_sentences = arg_params.coherent_sentences
        params.coupling_alpha = arg_params.coupling_alpha
        params.coupling_beta = arg_params.coupling_beta

    assert args.replace or \
        not os.path.isdir(os.path.join(args.output_root, args.model_path,
                                       get_model_name(args, params))) or \
        not (args.load_model and not args.validate_only), \
        '{} already exists. If you want to replace it or resume training ' \
        'please use the --replace flag. If you want to validate a loaded ' \
        'model without training it, use the --validate_only flag. Otherwise ' \
        'specify a different model name using the --model_name flag.'.format(
            os.path.join(args.output_root, args.model_path,
                         get_model_name(args, params)))

    if args.load_model:
        print("Final model parameters (loaded model + command arguments): ")
        print(params)

    ##############################
    # Load dataset configuration #
    ##############################

    dataset_configs = DatasetParams(args.dataset_config_file)

    if args.dataset is None and not args.validate_only:
        print('ERROR: No dataset selected!')
        print('Please supply a training dataset with the argument '
              '--dataset DATASET')
        print('The following datasets are configured in {}:'.format(
            args.dataset_config_file))
        for ds, _ in dataset_configs.config.items():
            if ds not in ('DEFAULT', 'generic'):
                print(' ', ds)
        sys.exit(1)

    if args.validate_only:
        if args.load_model is None:
            print('ERROR: for --validate_only you need to specify a model '
                  'to evaluate using --load_model MODEL')
            sys.exit(1)
    else:
        dataset_params = dataset_configs.get_params(args.dataset)

        for i in dataset_params:
            i.config_dict['no_tokenize'] = args.no_tokenize
            i.config_dict['show_tokens'] = args.show_tokens
            i.config_dict['skip_start_token'] = params.skip_start_token

            if params.hierarchical_model:
                i.config_dict['hierarchical_model'] = True
                i.config_dict['max_sentences'] = params.max_sentences
                i.config_dict['crop_regions'] = False

    if args.validate is not None:
        validation_dataset_params = dataset_configs.get_params(args.validate)
        for i in validation_dataset_params:
            i.config_dict['no_tokenize'] = args.no_tokenize
            i.config_dict['show_tokens'] = args.show_tokens
            i.config_dict['skip_start_token'] = params.skip_start_token

            if params.hierarchical_model:
                i.config_dict['hierarchical_model'] = True
                i.config_dict['max_sentences'] = params.max_sentences
                i.config_dict['crop_regions'] = False

    #######################
    # Load the vocabulary #
    #######################

    # For pre-trained models, attempt to obtain the saved vocabulary from
    # the model itself:
    if args.load_model and params.vocab is not None:
        print("Loading vocabulary from the model file:")
        vocab = params.vocab
    else:
        if args.vocab is None:
            print("ERROR: You must specify the vocabulary to be used for "
                  "training using the --vocab flag.\nTry --vocab AUTO if you "
                  "want the vocabulary to be either generated from the "
                  "training dataset or loaded from cache.")
            sys.exit(1)
        print("Loading / generating vocabulary:")
        vocab = get_vocab(args, dataset_params)

    print('Size of the vocabulary is {}'.format(len(vocab)))

    ##########################
    # Initialize data loader #
    ##########################

    ext_feature_sets = [params.features.external,
                        params.persist_features.external]
    if not args.validate_only:
        print('Loading dataset: {} with {} workers'.format(
            args.dataset, args.num_workers))
        if params.skip_start_token:
            print("Skipping the use of <start> token...")
        data_loader, ef_dims = get_loader(
            dataset_params, vocab, transform, args.batch_size,
            shuffle=True, num_workers=args.num_workers,
            ext_feature_sets=ext_feature_sets,
            skip_images=not params.has_internal_features(),
            verbose=args.verbose, unique_ids=sc_will_happen)
        if sc_will_happen:
            gts_sc = get_ground_truth_captions(data_loader.dataset)

    gts_sc_valid = None
    if args.validate is not None:
        valid_loader, ef_dims = get_loader(
            validation_dataset_params, vocab, transform, args.batch_size,
            shuffle=True, num_workers=args.num_workers,
            ext_feature_sets=ext_feature_sets,
            skip_images=not params.has_internal_features(),
            verbose=args.verbose)
        gts_sc_valid = get_ground_truth_captions(
            valid_loader.dataset) if sc_will_happen else None

    #########################################
    # Setup (optional) TensorBoardX logging #
    #########################################

    writer = None
    if args.tensorboard:
        if SummaryWriter is not None:
            model_name = get_model_name(args, params)
            timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
            log_dir = os.path.join(
                args.output_root,
                'log_tb/{}_{}'.format(model_name, timestamp))
            writer = SummaryWriter(log_dir=log_dir)
            print("INFO: Logging TensorBoardX events to {}".format(log_dir))
        else:
            print("WARNING: SummaryWriter object not available. "
                  "Hint: Please install TensorBoardX using "
                  "pip install tensorboardx")

    ######################
    # Build the model(s) #
    ######################

    # Set a per-parameter learning rate here, if supplied by the user:
    if args.lr_word_decoder is not None:
        if not params.hierarchical_model:
            print("ERROR: Setting the word decoder learning rate is "
                  "currently supported in the hierarchical model only.")
            sys.exit(1)
        lr_dict = {'word_decoder': args.lr_word_decoder}
    else:
        lr_dict = {}

    model = EncoderDecoder(params, device, len(vocab), state, ef_dims,
                           lr_dict=lr_dict)

    ######################
    # Optimizer and loss #
    ######################

    sc_activated = False
    opt_params = model.get_opt_params()

    # Loss and optimizer:
    if params.hierarchical_model:
        criterion = HierarchicalXEntropyLoss(
            weight_sentence_loss=params.weight_sentence_loss,
            weight_word_loss=params.weight_word_loss)
    elif args.share_embedding_weights:
        criterion = SharedEmbeddingXentropyLoss(param_lambda=0.15)
    else:
        criterion = nn.CrossEntropyLoss()

    if sc_will_happen:  # save the self-critical loss for later
        if args.self_critical_loss == 'sc':
            from model.loss import SelfCriticalLoss
            rl_criterion = SelfCriticalLoss()
        elif args.self_critical_loss == 'sc_with_diversity':
            from model.loss import SelfCriticalWithDiversityLoss
            rl_criterion = SelfCriticalWithDiversityLoss()
        elif args.self_critical_loss == 'sc_with_relative_diversity':
            from model.loss import SelfCriticalWithRelativeDiversityLoss
            rl_criterion = SelfCriticalWithRelativeDiversityLoss()
        elif args.self_critical_loss == 'sc_with_bleu_diversity':
            from model.loss import SelfCriticalWithBLEUDiversityLoss
            rl_criterion = SelfCriticalWithBLEUDiversityLoss()
        elif args.self_critical_loss == 'sc_with_repetition':
            from model.loss import SelfCriticalWithRepetitionLoss
            rl_criterion = SelfCriticalWithRepetitionLoss()
        elif args.self_critical_loss == 'mixed':
            from model.loss import MixedLoss
            rl_criterion = MixedLoss()
        elif args.self_critical_loss == 'mixed_with_face':
            from model.loss import MixedWithFACELoss
            rl_criterion = MixedWithFACELoss(vocab_size=len(vocab))
        elif args.self_critical_loss in ['sc_with_penalty',
                                         'sc_with_penalty_throughout',
                                         'sc_masked_tokens']:
            raise ValueError('Deprecated loss, use \'sc\' loss')
        else:
            raise ValueError('Invalid self-critical loss')

        print('Selected self-critical loss is', rl_criterion)

        if start_epoch >= args.self_critical_from_epoch:
            criterion = rl_criterion
            sc_activated = True
            print('Self-critical loss training begins')

    # When using CyclicalLR, the default learning rate should always be 1.0:
    if args.lr_scheduler == 'CyclicalLR':
        default_lr = 1.
    else:
        default_lr = 0.001

    if sc_activated:
        optimizer = torch.optim.Adam(
            opt_params,
            lr=args.learning_rate if args.learning_rate else 5e-5,
            weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(opt_params, lr=default_lr,
                                     weight_decay=args.weight_decay)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(opt_params, lr=default_lr,
                                        weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(opt_params, lr=default_lr,
                                    weight_decay=args.weight_decay)
    else:
        print('ERROR: unknown optimizer:', args.optimizer)
        sys.exit(1)

    # We don't want to initialize the optimizer if we are transferring
    # the language model from the regular model to the hierarchical model:
    transfer_language_model = False

    if arg_params.hierarchical_model and state and \
       not state.get('hierarchical_model'):
        transfer_language_model = True

    # Set the optimizer state to the one found in a loaded model, unless we
    # are doing a transfer learning step from a flat to a hierarchical model,
    # or we are using self-critical loss, or the number of unique parameter
    # groups has changed, or the user has explicitly told us *not* to reuse
    # the optimizer parameters from before:
    if state and not transfer_language_model and not sc_activated and \
       not args.optimizer_reset:
        # Check that the number of parameter groups is the same:
        if len(optimizer.param_groups) == \
           len(state['optimizer']['param_groups']):
            optimizer.load_state_dict(state['optimizer'])

    # Override the lr if set explicitly in the arguments:
    # 1) Global learning rate:
    if args.learning_rate:
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.learning_rate
        params.learning_rate = args.learning_rate
    else:
        params.learning_rate = default_lr

    # 2) Parameter-group specific learning rate:
    if args.lr_word_decoder is not None:
        # We want to give the user an option to set the learning rate for
        # word_decoder separately. Other exceptions can be added as needed:
        for param_group in optimizer.param_groups:
            if param_group.get('name') == 'word_decoder':
                param_group['lr'] = args.lr_word_decoder
                break

    if args.validate is not None and args.lr_scheduler == 'ReduceLROnPlateau':
        print('Using ReduceLROnPlateau learning rate scheduler')
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, 'min', verbose=True, patience=2)
    elif args.lr_scheduler == 'StepLR':
        print('Using StepLR learning rate scheduler with step_size {}'.format(
            args.lr_step_size))
        # Decrease the learning rate by the factor of gamma every step_size
        # epochs (for example every 5 or 10 epochs):
        step_size = args.lr_step_size
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size,
                                                    gamma=0.5, last_epoch=-1)
    elif args.lr_scheduler == 'CyclicalLR':
        print("Using Cyclical learning rate scheduler, lr range: [{},{}]".
              format(args.lr_cyclical_min, args.lr_cyclical_max))

        step_size = len(data_loader)
        clr = cyclical_lr(step_size, min_lr=args.lr_cyclical_min,
                          max_lr=args.lr_cyclical_max)
        n_groups = len(optimizer.param_groups)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      [clr] * n_groups)
    elif args.lr_scheduler is not None:
        print('ERROR: Invalid learning rate scheduler specified: {}'.format(
            args.lr_scheduler))
        sys.exit(1)

    ###################
    # Train the model #
    ###################

    stats_postfix = None
    if args.validate_only:
        stats_postfix = args.validate
    if args.load_model:
        all_stats = init_stats(args, params, postfix=stats_postfix)
    else:
        all_stats = {}

    if args.force_epoch:
        start_epoch = args.force_epoch - 1

    if not args.validate_only:
        total_step = len(data_loader)
        print('Start training with start_epoch={:d} num_epochs={:d} '
              'num_batches={:d} ...'.format(start_epoch, args.num_epochs,
                                            args.num_batches))
        if args.teacher_forcing != 'always':
            print('\t k: {}'.format(args.teacher_forcing_k))
            print('\t beta: {}'.format(args.teacher_forcing_beta))
        print('Optimizer:', optimizer)

    if args.validate_only:
        stats = {}
        teacher_p = 1.0
        if args.teacher_forcing != 'always':
            print('WARNING: teacher_forcing!=always, not yet implemented for '
                  '--validate_only mode')

        epoch = start_epoch - 1
        if str(epoch + 1) in all_stats.keys() and \
           args.skip_existing_validations:
            print('WARNING: epoch {} already validated, skipping...'.format(
                epoch + 1))
            return

        val_loss = do_validate(model, valid_loader, criterion, scorers, vocab,
                               teacher_p, args, params, stats, epoch,
                               sc_activated, gts_sc_valid)
        all_stats[str(epoch + 1)] = stats
        save_stats(args, params, all_stats, postfix=stats_postfix)
    else:
        for epoch in range(start_epoch, args.num_epochs):
            stats = {}
            begin = datetime.now()

            total_loss = 0
            if params.hierarchical_model:
                total_loss_sent = 0
                total_loss_word = 0

            num_batches = 0
            vocab_counts = {'cnt': 0, 'max': 0, 'min': 9999,
                            'sum': 0, 'unk_cnt': 0, 'unk_sum': 0}

            # Switch to self-critical training if it is time to do so:
            if not sc_activated and sc_will_happen and \
               epoch >= args.self_critical_from_epoch:
                if all_stats:
                    best_ep, best_cider = max(
                        [(ep, all_stats[ep]['validation_cider'])
                         for ep in all_stats], key=lambda x: x[1])
                    print('Loading model from epoch', best_ep,
                          'which has the best score of', best_cider)
                    state = torch.load(get_model_path(args, params,
                                                      int(best_ep)))
                    model = EncoderDecoder(params, device, len(vocab), state,
                                           ef_dims, lr_dict=lr_dict)
                    opt_params = model.get_opt_params()

                optimizer = torch.optim.Adam(opt_params, lr=5e-5,
                                             weight_decay=args.weight_decay)
                criterion = rl_criterion
                print('Self-critical loss training begins')
                sc_activated = True

            for i, data in enumerate(data_loader):
                if params.hierarchical_model:
                    (images, captions, lengths, image_ids, features,
                     sorting_order, last_sentence_indicator) = data
                    sorting_order = sorting_order.to(device)
                else:
                    (images, captions, lengths, image_ids, features) = data

                if epoch == 0:
                    unk = vocab('<unk>')
                    for j in range(captions.shape[0]):
                        # Flatten the caption in case it is a paragraph;
                        # this is harmless for regular captions too:
                        xl = captions[j, :].view(-1)
                        xw = xl > unk
                        xu = xl == unk
                        xwi = sum(xw).item()
                        xui = sum(xu).item()
                        vocab_counts['cnt'] += 1
                        vocab_counts['sum'] += xwi
                        vocab_counts['max'] = max(vocab_counts['max'], xwi)
                        vocab_counts['min'] = min(vocab_counts['min'], xwi)
                        vocab_counts['unk_cnt'] += xui > 0
                        vocab_counts['unk_sum'] += xui

                # Set the mini-batch dataset:
                images = images.to(device)
                captions = captions.to(device)

                # Remove the <start> token from the targets if we are
                # initializing the RNN hidden state from image features:
                if params.rnn_hidden_init == 'from_features' and \
                   not params.hierarchical_model:
                    # Subtract one from all lengths to match the new
                    # target lengths:
                    lengths = [x - 1 if x > 0 else x for x in lengths]
                    targets = pack_padded_sequence(captions[:, 1:], lengths,
                                                   batch_first=True)[0]
                else:
                    if params.hierarchical_model:
                        targets = prepare_hierarchical_targets(
                            last_sentence_indicator, args.max_sentences,
                            lengths, captions, device)
                    else:
                        targets = pack_padded_sequence(captions, lengths,
                                                       batch_first=True)[0]
                        sorting_order = None

                init_features = features[0].to(device) if len(features) > 0 \
                    and features[0] is not None else None
                persist_features = features[1].to(device) if len(features) > 1 \
                    and features[1] is not None else None

                # Forward, backward and optimize.
                # Calculate the probability of using teacher forcing,
                # iterating over batches:
                iteration = (epoch - start_epoch) * len(data_loader) + i
                teacher_p = get_teacher_prob(args.teacher_forcing_k, iteration,
                                             args.teacher_forcing_beta)

                # Allow the model to log values at the last batch of the epoch:
                writer_data = None
                if writer and (i == len(data_loader) - 1 or
                               i == args.num_batches - 1):
                    writer_data = {'writer': writer, 'epoch': epoch + 1}

                sample_len = captions.size(1) if args.self_critical_loss in \
                    ['mixed', 'mixed_with_face'] else 20

                if sc_activated:
                    sampled_seq, sampled_log_probs, outputs = model.sample(
                        images, init_features, persist_features,
                        max_seq_length=sample_len,
                        start_token_id=vocab('<start>'),
                        trigram_penalty_alpha=args.trigram_penalty_alpha,
                        stochastic_sampling=True, output_logprobs=True,
                        output_outputs=True)
                    sampled_seq = model.decoder.alt_prob_to_tensor(
                        sampled_seq, device=device)
                else:
                    outputs = model(images, init_features, captions, lengths,
                                    persist_features, teacher_p,
                                    args.teacher_forcing, sorting_order,
                                    writer_data=writer_data)

                if args.share_embedding_weights:
                    # Weights of the (HxH) projection matrix used for
                    # regularizing models that share embedding weights:
                    projection = model.decoder.projection.weight
                    loss = criterion(projection, outputs, targets)
                elif sc_activated:
                    # Get the greedy decoding baseline:
                    model.eval()
                    with torch.no_grad():
                        greedy_sampled_seq = model.sample(
                            images, init_features, persist_features,
                            max_seq_length=sample_len,
                            start_token_id=vocab('<start>'),
                            trigram_penalty_alpha=args.trigram_penalty_alpha,
                            stochastic_sampling=False)
                        greedy_sampled_seq = model.decoder.alt_prob_to_tensor(
                            greedy_sampled_seq, device=device)
                    model.train()

                    if args.self_critical_loss in [
                            'sc', 'sc_with_diversity',
                            'sc_with_relative_diversity',
                            'sc_with_bleu_diversity', 'sc_with_repetition']:
                        loss, advantage = criterion(
                            sampled_seq, sampled_log_probs, greedy_sampled_seq,
                            [gts_sc[i] for i in image_ids], scorers, vocab,
                            return_advantage=True)
                    elif args.self_critical_loss in ['mixed']:
                        loss, advantage = criterion(
                            sampled_seq, sampled_log_probs, outputs,
                            greedy_sampled_seq,
                            [gts_sc[i] for i in image_ids], scorers, vocab,
                            targets, lengths, gamma_ml_rl=args.gamma_ml_rl,
                            return_advantage=True)
                    elif args.self_critical_loss in ['mixed_with_face']:
                        loss, advantage = criterion(
                            sampled_seq, sampled_log_probs, outputs,
                            greedy_sampled_seq,
                            [gts_sc[i] for i in image_ids], scorers, vocab,
                            captions, targets, lengths,
                            gamma_ml_rl=args.gamma_ml_rl,
                            return_advantage=True)
                    else:
                        raise ValueError('Invalid self-critical loss')

                    if writer is not None and i % 100 == 0:
                        writer.add_scalar('training_loss', loss.item(),
                                          epoch * len(data_loader) + i)
                        writer.add_scalar('advantage', advantage,
                                          epoch * len(data_loader) + i)
                        writer.add_scalar('lr',
                                          optimizer.param_groups[0]['lr'],
                                          epoch * len(data_loader) + i)
                else:
                    loss = criterion(outputs, targets)

                model.zero_grad()
                loss.backward()

                # Clip gradients if desired:
                if args.grad_clip is not None:
                    clip_gradients(optimizer, args.grad_clip)

                # Update the weights:
                optimizer.step()

                # CyclicalLR requires us to update the LR at every minibatch:
                if args.lr_scheduler == 'CyclicalLR':
                    scheduler.step()

                total_loss += loss.item()
                num_batches += 1

                if params.hierarchical_model:
                    _, loss_sent, _, loss_word = criterion.item_terms()
                    total_loss_sent += float(loss_sent)
                    total_loss_word += float(loss_word)

                # Print log info:
                if (i + 1) % args.log_step == 0:
                    print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, '
                          'Perplexity: {:5.4f}'.format(
                              epoch + 1, args.num_epochs, i + 1, total_step,
                              loss.item(), np.exp(loss.item())))
                    sys.stdout.flush()

                    if params.hierarchical_model:
                        weight_sent, loss_sent, weight_word, loss_word = \
                            criterion.item_terms()
                        print('Sentence Loss: {:.4f}, '
                              'Word Loss: {:.4f}'.format(float(loss_sent),
                                                         float(loss_word)))
                        sys.stdout.flush()

                if i + 1 == args.num_batches:
                    break

            end = datetime.now()

            stats['training_loss'] = total_loss / num_batches

            if params.hierarchical_model:
                stats['loss_sentence'] = total_loss_sent / num_batches
                stats['loss_word'] = total_loss_word / num_batches

            print('Epoch {} duration: {}, average loss: {:.4f}'.format(
                epoch + 1, end - begin, stats['training_loss']))

            save_model(args, params, model.encoder, model.decoder, optimizer,
                       epoch, vocab)

            if epoch == 0:
                vocab_counts['avg'] = vocab_counts['sum'] / vocab_counts['cnt']
                vocab_counts['unk_cnt_per'] = \
                    100 * vocab_counts['unk_cnt'] / vocab_counts['cnt']
                vocab_counts['unk_sum_per'] = \
                    100 * vocab_counts['unk_sum'] / vocab_counts['sum']
                print(('Training data contains {sum} words in {cnt} captions '
                       '(avg. {avg:.1f} w/c) with {unk_sum} <unk>s '
                       '({unk_sum_per:.1f}%) in {unk_cnt} '
                       '({unk_cnt_per:.1f}%) captions').format(**vocab_counts))

            ############################################
            # Validation loss and learning rate update #
            ############################################

            if args.validate is not None and \
               (epoch + 1) % args.validation_step == 0:
                val_loss = do_validate(model, valid_loader, criterion,
                                       scorers, vocab, teacher_p, args,
                                       params, stats, epoch, sc_activated,
                                       gts_sc_valid)

                if args.lr_scheduler == 'ReduceLROnPlateau':
                    scheduler.step(val_loss)
                elif args.lr_scheduler == 'StepLR':
                    scheduler.step()

            all_stats[str(epoch + 1)] = stats
            save_stats(args, params, all_stats, writer=writer)

            if writer is not None:
                # Log the model data to tensorboard:
                log_model_data(params, model, epoch + 1, writer)

    if writer is not None:
        writer.close()
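
# main() calls cyclical_lr(step_size, min_lr, max_lr) and hands the result to
# LambdaLR with a base lr of 1.0, so the returned function must yield the
# absolute learning rate for a given iteration. Below is a minimal sketch of a
# triangular cyclical schedule with that contract; the actual cyclical_lr in
# this repository may differ in shape or scaling.
import math

def cyclical_lr_sketch(step_size, min_lr=3e-4, max_lr=3e-3):
    def relative(it, size):
        # Position within the current cycle, rising then falling linearly:
        cycle = math.floor(1 + it / (2 * size))
        x = abs(it / size - 2 * cycle + 1)
        return max(0., 1. - x)

    # LambdaLR multiplies this value by the optimizer's base lr (set to 1.0):
    return lambda it: min_lr + (max_lr - min_lr) * relative(it, step_size)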