Example 1
def main(args):
    dataset_configs = DatasetParams(args.dataset_config_file)
    dataset_params_1 = dataset_configs.get_params(args.dataset_1)
    dataset_params_2 = dataset_configs.get_params(args.dataset_2)

    for p in (dataset_params_1, dataset_params_2):
        for d in p:
            # Tell dataset to output id in integer or other simple format:
            d.config_dict['return_simple_image_id'] = True

    data_loader_1, _ = get_loader(dataset_params_1,
                                  vocab=None,
                                  transform=None,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.num_workers,
                                  ext_feature_sets=None,
                                  skip_images=True,
                                  iter_over_images=True)

    data_loader_2, _ = get_loader(dataset_params_2,
                                  vocab=None,
                                  transform=None,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.num_workers,
                                  ext_feature_sets=None,
                                  skip_images=True,
                                  iter_over_images=True)

    show_progress = sys.stderr.isatty()
    print("Reading image ids from dataset {}".format(args.dataset_1))
    ids_1 = [
        img_ids for _, _, _, img_ids, _ in tqdm(data_loader_1,
                                                disable=not show_progress)
    ]
    set_1 = set(chain(*ids_1))

    print("Reading image ids from dataset {}".format(args.dataset_2))
    ids_2 = [
        img_ids for _, _, _, img_ids, _ in tqdm(data_loader_2,
                                                disable=not show_progress)
    ]
    set_2 = set(chain(*ids_2))

    intersection = set_1.intersection(set_2)
    len_intersect = len(intersection)

    print("There are {} images shared between {} and {}".format(
        len_intersect, args.dataset_1, args.dataset_2))
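The list comprehension above collects one list of ids per batch, so each set is built from the flattened stream. A minimal sketch of that flattening step, with made-up batches:

from itertools import chain

# Hypothetical id batches as yielded by two data loaders:
batches_1 = [[1, 2, 3], [4, 5]]
batches_2 = [[3, 4], [6]]

set_1 = set(chain(*batches_1))   # {1, 2, 3, 4, 5}
set_2 = set(chain(*batches_2))   # {3, 4, 6}
print(len(set_1.intersection(set_2)))  # 2 shared ids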
Example 2
def get_feature_dims(state, args):
    dataset_configs = DatasetParams(args.dataset_config_file)
    dataset_params = dataset_configs.get_params(args.dataset)[0]

    features_paths = dataset_params.features_path

    ext_feature_sets = [
        state['features'].external, state['persist_features'].external
    ]
    loaders_and_dims = [
        ExternalFeature.loaders(fs, features_paths) for fs in ext_feature_sets
    ]

    loaders, dims = zip(*loaders_and_dims)
    return dims
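The `zip(*loaders_and_dims)` call above is the standard unzip idiom: a list of (loader, dim) pairs becomes one tuple of loaders and one tuple of dims. A toy illustration with placeholder values:

# Hypothetical (loader, dim) pairs:
loaders_and_dims = [('loader_a', 2048), ('loader_b', 1024)]
loaders, dims = zip(*loaders_and_dims)
print(loaders)  # ('loader_a', 'loader_b')
print(dims)     # (2048, 1024)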
Example 3
def check_dataset(args):
    if args.dataset is None:
        print('ERROR: No dataset selected!')
        print(
            'Please supply a training dataset with the argument --dataset DATASET'
        )
        print('The following datasets are configured in {}:'.format(
            args.dataset_config_file))

        dataset_configs = DatasetParams(args.dataset_config_file)
        for ds, _ in dataset_configs.config.items():
            if ds not in ('DEFAULT', 'generic'):
                print(' ', ds)

        sys.exit(1)

    return args.dataset
Example 4
    def infer(self, args):
        # print('infer() :', args)

        if 'image_features' not in args:
            args['image_features'] = None

        # Image preprocessing
        transform = transforms.Compose([
            transforms.Resize((args['resize'], args['resize'])),
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))])

        # Get dataset parameters:
        dataset_configs = DatasetParams(args['dataset_config_file'])
        dataset_params = dataset_configs.get_params(args['dataset'],
                                                    args['image_dir'],
                                                    args['image_files'],
                                                    args['image_features'])

        if self.params.has_external_features() and \
           any(dc.name == 'generic' for dc in dataset_params):
            print('WARNING: you cannot use external features without specifying all datasets in '
                  'datasets.conf.')
            print('Hint: take a look at datasets/datasets.conf.default.')

        # Build data loader
        print("Loading dataset: {}".format(args['dataset']))

        # Update dataset params with needed model params:
        for i in dataset_params:
            i.config_dict['skip_start_token'] = self.params.skip_start_token
            # For visualizing attention we need file names instead of IDs in our output:
            if args['store_image_paths']:
                i.config_dict['return_image_file_name'] = True

        ext_feature_sets = [self.params.features.external, self.params.persist_features.external]
        if args['dataset'] == 'incore':
            ext_feature_sets = None

        # We ask it to iterate over images instead of all (image, caption) pairs
        data_loader, ef_dims = get_loader(dataset_params, vocab=None, transform=transform,
                                          batch_size=args['batch_size'], shuffle=False,
                                          num_workers=args['num_workers'],
                                          ext_feature_sets=ext_feature_sets,
                                          skip_images=not self.params.has_internal_features(),
                                          iter_over_images=True)

        self.data_loader = data_loader
        # Create model directory
        if not os.path.exists(args['results_path']):
            os.makedirs(args['results_path'])

        scorers = {}
        if args['scoring'] is not None:
            for s in args['scoring'].split(','):
                s = s.lower().strip()
                if s == 'cider':
                    from eval.cider import Cider
                    scorers['CIDEr'] = Cider(df='corpus')

        # Store captions here:
        output_data = []

        gts = {}
        res = {}

        print('Starting inference, max sentence length: {} num_workers: {}'.format(
            args['max_seq_length'], args['num_workers']))
        show_progress = sys.stderr.isatty() and not args['verbose'] \
                        and ext_feature_sets is not None

        for i, (images, ref_captions, lengths, image_ids,
                features) in enumerate(tqdm(self.data_loader, disable=not show_progress)):

            if len(scorers) > 0:
                for j in range(len(ref_captions)):
                    jid = image_ids[j]
                    if jid not in gts:
                        gts[jid] = []
                    rcs = ref_captions[j]
                    if isinstance(rcs, str):
                        rcs = [rcs]
                    for rc in rcs:
                        gts[jid].append(rc.lower())

            images = images.to(device)

            init_features    = features[0].to(device) if len(features) > 0 and \
                               features[0] is not None else None
            persist_features = features[1].to(device) if len(features) > 1 and \
                               features[1] is not None else None

            # Generate a caption from the image
            sampled_batch = self.model.sample(images, init_features, persist_features,
                                              max_seq_length=args['max_seq_length'],
                                              start_token_id=self.vocab('<start>'),
                                              end_token_id=self.vocab('<end>'),
                                              alternatives=args['alternatives'],
                                              probabilities=args['probabilities'])

            sampled_ids_batch = sampled_batch

            for j in range(len(sampled_ids_batch)):
                sampled_ids = sampled_ids_batch[j]

                # Convert word_ids to words
                if self.params.hierarchical_model:
                    # assert False, 'paragraph_ids_to_words() need to be updated'
                    caption = paragraph_ids_to_words(sampled_ids, self.vocab,
                                                     skip_start_token=True)
                else:
                    caption = caption_ids_ext_to_words(sampled_ids, self.vocab,
                                                       skip_start_token=True,
                                                       capitalize=not args['no_capitalize'])
                if args['no_repeat_sentences']:
                    caption = remove_duplicate_sentences(caption)

                if args['only_complete_sentences']:
                    caption = remove_incomplete_sentences(caption)

                if args['verbose']:
                    print('=>', caption)

                caption = self.apply_lemma_pos_rules(caption)
                if args['verbose']:
                    print('#>', caption)

                output_data.append({'image_id': image_ids[j],
                                    'caption': caption})
                res[image_ids[j]] = [caption.lower()]

        for score_name, scorer in scorers.items():
            score = scorer.compute_score(gts, res)[0]
            print('Test', score_name, score)

        # Decide output format, fall back to txt
        if args['output_format'] is not None:
            output_format = args['output_format']
        elif args['output_file'] and args['output_file'].endswith('.json'):
            output_format = 'json'
        else:
            output_format = 'txt'

        # Create a sensible default output path for results:
        output_file = None
        if not args['output_file'] and not args['print_results']:
            model_name_path = Path(args['model'])
            is_in_same_folder = len(model_name_path.parents) == 1
            if not is_in_same_folder:
                model_name = args['model'].split(os.sep)[-2]
                model_epoch = basename(args['model'])
                output_file = '{}-{}.{}'.format(model_name, model_epoch,
                                                output_format)
            else:
                output_file = model_name_path.stem + '.' + output_format
        else:
            output_file = args['output_file']

        if output_file:
            output_path = os.path.join(args['results_path'], output_file)
            if output_format == 'json':
                json.dump(output_data, open(output_path, 'w'))
            else:
                with open(output_path, 'w') as fp:
                    for data in output_data:
                        print(data['image_id'], data['caption'], file=fp)

            print('Wrote generated captions to {} as {}'.format(
                output_path, output_format))

        if args['print_results']:
            for d in output_data:
                print('{}: {}'.format(d['image_id'], d['caption']))

        return output_data
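The `gts` and `res` dictionaries built in `infer()` follow the usual COCO-style scorer protocol: both map an image id to a list of lowercased captions, with `gts` holding all reference captions and `res` exactly one candidate per image. A hedged sketch of that layout with made-up ids and captions; the commented call mirrors how `compute_score` is used above:

gts = {42: ['a dog runs on grass', 'a brown dog outdoors'],
       43: ['a red car parked on a street']}
res = {42: ['a dog running in a field'],
       43: ['a car on the road']}

# corpus_score, per_image_scores = scorer.compute_score(gts, res)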
Example 5
def main(args):
    #
    # Image preprocessing
    if args.feature_type == 'plain':
        if args.extractor == 'resnet152caffe-original':
            # Use custom transform:
            transform = transforms.Compose([
                transforms.Resize((args.crop_size, args.crop_size)),
                # Swap color space from RGB to BGR and subtract caffe-specific
                # channel values from each pixel
                transforms.Lambda(lambda img: np.array(img, dtype=np.float32)[
                    ..., [2, 1, 0]] - [103.939, 116.779, 123.68]),
                # Create a torch tensor and put channels first:
                transforms.Lambda(
                    lambda img: torch.from_numpy(img).permute(2, 0, 1)),
                # Cast tensor to correct type:
                transforms.Lambda(lambda img: img.type('torch.FloatTensor'))
            ])

        else:
            # Default transform
            transform = transforms.Compose([
                transforms.Resize((args.crop_size, args.crop_size)),
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225))
            ])
    elif args.feature_type == 'avg' or args.feature_type == 'max':
        # Try with no normalization
        # Try with subtracting 0.5 from all values
        # See example here: https://pytorch.org/docs/stable/torchvision/transforms.html

        if args.normalize == 'default':
            transform = transforms.Compose([
                transforms.Resize((args.image_size, args.image_size)),
                # 10-crop implementation as described in PyTorch documentation:
                transforms.TenCrop((args.crop_size, args.crop_size)),
                # Apply next two transforms to each crop in turn and then stack them
                # to a single tensor:
                transforms.Lambda(lambda crops: torch.stack([
                    transforms.Normalize((0.485, 0.456, 0.406),
                                         (0.229, 0.224, 0.225))
                    (transforms.ToTensor()(crop)) for crop in crops
                ]))
            ])
        elif args.normalize == 'skip':
            transform = transforms.Compose([
                transforms.Resize((args.image_size, args.image_size)),
                transforms.TenCrop((args.crop_size, args.crop_size)),
                transforms.Lambda(lambda crops: torch.stack(
                    [transforms.ToTensor()(crop) for crop in crops]))
            ])
        elif args.normalize == 'subtract_half':
            transform = transforms.Compose([
                transforms.Resize((args.image_size, args.image_size)),
                transforms.TenCrop((args.crop_size, args.crop_size)),
                transforms.Lambda(lambda crops: torch.stack(
                    [transforms.ToTensor()(crop) for crop in crops]) - 0.5)
            ])
        else:
            print("Invalid normalization parameter")
            sys.exit(1)

    else:
        print("Invalid feature type specified {}".args.feature_type)
        sys.exit(1)

    print("Creating features of type: {}".format(args.feature_type))

    # Get dataset parameters and vocabulary wrapper:
    dataset_configs = DatasetParams(args.dataset_config_file)
    dataset_params = dataset_configs.get_params(args.dataset)

    # We want to only get the image file name, not the full path:
    for i in dataset_params:
        i.config_dict['return_image_file_name'] = True

    # We ask it to iterate over images instead of all (image, caption) pairs
    data_loader, _ = get_loader(dataset_params,
                                vocab=None,
                                transform=transform,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.num_workers,
                                ext_feature_sets=None,
                                skip_images=False,
                                iter_over_images=True)

    extractor = FeatureExtractor(args.extractor, True).to(device).eval()

    # To open an lmdb handle and prepare it for the right size
    # it needs to fit the total number of elements in the dataset
    # so we set a map_size to a largish value here:
    map_size = int(1e12)

    lmdb_path = None
    file_name = None

    if args.output_file:
        file_name = args.output_file
    else:
        file_name = '{}-{}-{}-normalize-{}.lmdb'.format(
            args.dataset, args.extractor, args.feature_type, args.normalize)

    os.makedirs(args.output_dir, exist_ok=True)

    lmdb_path = os.path.join(args.output_dir, file_name)

    # Check that we are not overwriting anything
    if os.path.exists(lmdb_path):
        print(
            'ERROR: {} exists, please remove it first if you really want to replace it.'
            .format(lmdb_path))
        sys.exit(1)

    print("Preparing to store extracted features to {}...".format(lmdb_path))

    print("Starting to extract features from dataset {} using {}...".format(
        args.dataset, args.extractor))
    show_progress = sys.stderr.isatty()

    # If feature shape is not 1-dimensional, store feature shape metadata:
    if isinstance(extractor.output_dim, np.ndarray):
        with lmdb.open(lmdb_path, map_size=map_size) as env:
            with env.begin(write=True) as txn:
                txn.put(b'@vdim', extractor.output_dim)

    for i, (images, _, _, image_ids,
            _) in enumerate(tqdm(data_loader, disable=not show_progress)):

        images = images.to(device)

        # If we are dealing with cropped images, image dimensions are: bs, ncrops, c, h, w
        if images.dim() == 5:
            bs, ncrops, c, h, w = images.size()
            # fuse batch size and ncrops:
            raw_features = extractor(images.view(-1, c, h, w))

            if args.feature_type == 'avg':
                # Average over crops:
                features = raw_features.view(bs, ncrops,
                                             -1).mean(1).data.cpu().numpy()
            elif args.feature_type == 'max':
                # Max over crops:
                features = raw_features.view(bs, ncrops,
                                             -1).max(1)[0].data.cpu().numpy()
        # Otherwise our image dimensions are bs, c, h, w
        else:
            features = extractor(images).data.cpu().numpy()

        # Write to LMDB object:
        with lmdb.open(lmdb_path, map_size=map_size) as env:
            with env.begin(write=True) as txn:
                for j, image_id in enumerate(image_ids):

                    # If output dimension is not a scalar, flatten the array.
                    # When retrieving this feature from the LMDB, developer must take
                    # care to reshape the feature back to the correct dimensions!
                    if isinstance(extractor.output_dim, np.ndarray):
                        _feature = features[j].flatten()
                    # Otherwise treat it as is:
                    else:
                        _feature = features[j]

                    txn.put(str(image_id).encode('ascii'), _feature)

        # Print log info
        if not show_progress and ((i + 1) % args.log_step == 0):
            print('Batch [{}/{}]'.format(i + 1, len(data_loader)))
            sys.stdout.flush()
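As the comment in the write loop warns, multi-dimensional features are flattened before storage and must be reshaped on the way out. A hedged sketch of the read side, reusing `lmdb_path` and an `image_id` from above, assuming the features were written as float32; the dtype of the stored '@vdim' shape array is also an assumption:

import lmdb
import numpy as np

with lmdb.open(lmdb_path, readonly=True) as env:
    with env.begin() as txn:
        vdim = txn.get(b'@vdim')  # None when features are 1-dimensional
        raw = txn.get(str(image_id).encode('ascii'))
        feature = np.frombuffer(raw, dtype=np.float32)
        if vdim is not None:
            shape = np.frombuffer(vdim, dtype=np.int64)  # dtype assumed
            feature = feature.reshape(shape)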
Example 6
def main(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = \
        ','.join(str(gpu) for gpu in args.visible_gpus)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    dataset = args.dataset
    net_name = args.cnn_name
    root_path = '.'
    data_path = os.path.join(root_path, "data")
    save_path = os.path.join(root_path, "results",
                             "iDLG_%s_%s" % (dataset, net_name))

    if args.add_clamp:
        save_path += "_clamp"

    # Some running setting
    initial_lr = args.lr
    num_dummy = 1
    Iteration = args.max_iter
    plot_steps = args.plot_steps
    # run_methods = ["iDLG", "DLG"]
    run_methods = ["iDLG"]

    tt = transforms.Compose([transforms.ToTensor()])
    tp = transforms.Compose([transforms.ToPILImage()])

    print(dataset, 'root_path:', root_path)
    print(dataset, 'data_path:', data_path)
    print(dataset, 'save_path:', save_path)

    if not os.path.exists('results'):
        os.mkdir('results')
    if not os.path.exists(save_path):
        os.mkdir(save_path)
    """ load data """

    data_params = DatasetParams()
    data_params.config(name=dataset, root_path=root_path, data_path=data_path)

    shape_img = data_params.shape_img
    num_classes = data_params.num_classes
    channel = data_params.channel
    dst = data_params.dst
    selected_indices = data_params.selected_indices
    cmap = data_params.cmap

    # Build ConvNet
    net = config_net(net_name=net_name,
                     input_shape=(channel, ) + shape_img,
                     num_classes=num_classes)
    net = net.to(device)
    # net.eval()

    # Load model pretrain weights
    if os.path.isfile(args.model_ckpt):
        ckpt = torch.load(args.model_ckpt)
        net.load_state_dict(ckpt)

    num_success = 0
    num_exp = len(selected_indices)
    criterion = nn.CrossEntropyLoss().to(device)
    ''' train DLG and iDLG '''
    for idx_exp in range(num_exp):

        print('running %d|%d experiment' % (idx_exp, num_exp))
        np.random.seed(idx_exp)
        # idx_shuffle = np.random.permutation(len(dst))

        for method in run_methods:
            print('%s, Try to generate %d images' % (method, num_dummy))

            # criterion = nn.CrossEntropyLoss().to(device)
            imidx_list = []

            # get ground truth image and label
            idx = selected_indices[idx_exp]
            imidx_list.append(idx)
            tmp_datum = tt(dst[idx][0]).float().to(device)
            tmp_datum = tmp_datum.view(1, *tmp_datum.size())
            tmp_label = torch.Tensor([dst[idx][1]]).long().to(device)
            tmp_label = tmp_label.view(1, )
            gt_data = tmp_datum
            gt_label = tmp_label

            # compute original gradient
            out = net(gt_data)
            y = criterion(out, gt_label)
            dy_dx = torch.autograd.grad(y, net.parameters())
            orig_dy_dx = list((t.detach().clone() for t in dy_dx))
            if args.grad_norm:
                grad_max = [x.max().item() for x in orig_dy_dx]
                grad_min = [x.min().item() for x in orig_dy_dx]
                orig_dy_dx = [
                    (g - g_min) / (g_max - g_min)
                    for g, g_min, g_max in zip(orig_dy_dx, grad_min, grad_max)
                ]

            # generate dummy data and label
            torch.manual_seed(10)
            dummy_data = torch.randn(
                gt_data.size()).to(device).requires_grad_(True)
            dummy_label = torch.randn(
                (gt_data.shape[0],
                 num_classes)).to(device).requires_grad_(True)

            # truncated dummy image and label
            if args.add_clamp:
                dummy_data.data = torch.clamp(dummy_data + 0.5, 0, 1)
                dummy_label.data = torch.clamp(dummy_label + 0.5, 0, 1)

            label_pred = None  # set below for iDLG; avoids a NameError if only DLG runs
            if method == 'DLG':
                # optim_obj = [dummy_data, dummy_label]
                optimizer = torch.optim.LBFGS(
                    [{
                        'params': [dummy_data, dummy_label],
                        'initial_lr': initial_lr
                    }],
                    lr=initial_lr,
                    max_iter=50,
                    tolerance_grad=1e-9,
                    tolerance_change=1e-11,
                    history_size=250,
                    line_search_fn="strong_wolfe")
            elif method == 'iDLG':
                # optim_obj = [dummy_data, ]
                optimizer = torch.optim.LBFGS([{
                    'params': [dummy_data],
                    'initial_lr': initial_lr
                }],
                                              lr=initial_lr,
                                              max_iter=50,
                                              tolerance_grad=1e-9,
                                              tolerance_change=1e-11,
                                              history_size=250,
                                              line_search_fn="strong_wolfe")
                # predict the ground-truth label
                label_pred = torch.argmin(torch.sum(orig_dy_dx[-2], dim=-1),
                                          dim=-1).detach().reshape(
                                              (1, )).requires_grad_(False)

            history = []
            history_iters = []
            losses = []
            mses = []
            train_iters = []

            scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                        step_size=30,
                                                        gamma=0.95,
                                                        last_epoch=-1)
            print('lr =', initial_lr)
            for iters in range(Iteration):

                closure = functools.partial(_closure,
                                            optimizer=optimizer,
                                            dummy_data=dummy_data,
                                            dummy_label=dummy_label,
                                            label_pred=label_pred,
                                            method=method,
                                            criterion=criterion,
                                            net=net,
                                            orig_dy_dx=orig_dy_dx,
                                            grad_norm=args.grad_norm)

                optimizer.step(closure)

                # pixel value clamp
                if args.add_clamp:
                    dummy_data.data = torch.clamp(dummy_data, 0, 1)

                current_loss = closure().item()
                train_iters.append(iters)
                losses.append(current_loss)
                mses.append(torch.mean((dummy_data - gt_data)**2).item())
                scheduler.step()

                if iters % plot_steps == 0:
                    current_time = get_current_time()
                    print(current_time, iters,
                          'loss = %.8f, mse = %.8f' % (current_loss, mses[-1]))
                    history.append([
                        tp(dummy_data[imidx].cpu())
                        for imidx in range(num_dummy)
                    ])
                    history_iters.append(iters)

                    for imidx in range(num_dummy):
                        plot_dummy_x(imidx, cmap, tp, gt_data, history,
                                     history_iters, save_path, method,
                                     selected_indices, idx_exp)

                    # if current_loss < 0.000001:
                    # converge
                    if mses[-1] < 1e-4:
                        break
            if mses[-1] < 1e-3:
                num_success += 1
            # Save mse curve
            plot_mse_curve(mses, iters, save_path, method, selected_indices,
                           idx_exp)

            if method == 'DLG':
                loss_DLG = losses
                label_DLG = torch.argmax(dummy_label, dim=-1).detach().item()
                mse_DLG = mses
            elif method == 'iDLG':
                loss_iDLG = losses
                label_iDLG = label_pred.item()
                mse_iDLG = mses

        print('gt_label:', gt_label.detach().cpu().data.numpy())
        if "DLG" in run_methods:
            print('loss_DLG:', loss_DLG[-1], 'mse_DLG:', mse_DLG[-1],
                  'lab_DLG:', label_DLG)
        if "iDLG" in run_methods:
            print('loss_iDLG:', loss_iDLG[-1], 'mse_iDLG:', mse_iDLG[-1],
                  'lab_iDLG:', label_iDLG)

        print('----------------------\n\n')

    print("Number of successful recover:", num_success)
Example 7
def main(args):

    if args.output_file:
        file_name = args.output_file
    else:
        if args.environment is not None:
            environment = args.environment
        else:
            environment = os.getenv('HOSTNAME')
            if environment is None:
                environment = 'unknown_host'
        file_name = 'image_file_list-{}-{}.txt'.format(args.dataset,
                                                       environment)

    file_name = os.path.join(args.output_path, file_name)

    os.makedirs(args.output_path, exist_ok=True)

    # If we want to generate multiple files we need to add "_X_of_Y" string to the file
    # to indicate which file out of the set it is:
    if args.num_files > 1:
        file_name_list = []
        for i in range(args.num_files):
            file_name_i = os.path.splitext(
                file_name)[0] + '_{}_of_{}.txt'.format(i + 1, args.num_files)
            file_name_list.append(file_name_i)
    else:
        file_name_list = None

    # Check that we are not overwriting anything
    if os.path.exists(file_name):
        print(
            'ERROR: {} exists, please remove it first if you really want to replace it.'
            .format(file_name))
        sys.exit(1)

    dataset_configs = DatasetParams(args.dataset_config_file)
    dataset_params = dataset_configs.get_params(args.dataset)

    for d in dataset_params:
        # Tell dataset to output full image paths instead of image id:
        d.config_dict['return_full_image_path'] = True

    # We ask it to iterate over images instead of all (image, caption) pairs
    data_loader, _ = get_loader(dataset_params,
                                vocab=None,
                                transform=None,
                                batch_size=args.batch_size,
                                shuffle=False,
                                num_workers=args.num_workers,
                                ext_feature_sets=None,
                                skip_images=True,
                                iter_over_images=True)

    print("Getting file paths from dataset {}...".format(args.dataset))
    show_progress = sys.stderr.isatty()
    for i, (_, _, _, paths,
            _) in enumerate(tqdm(data_loader, disable=not show_progress)):

        if args.num_files == 1:
            _file_name = file_name
        else:
            n = int(i * data_loader.batch_size * args.num_files /
                    len(data_loader.dataset))
            _file_name = file_name_list[n]

        with open(_file_name, 'a') as f:
            for path in paths:
                f.write(path + '\n')

        # Print log info
        if not show_progress and ((i + 1) % args.log_step == 0):
            print('Batch [{}/{}]'.format(i + 1, len(data_loader)))
            sys.stdout.flush()

    print("Written paths to {} image files".format(len(data_loader.dataset)))
Example 8
def main(args):
    if args.model_name is not None:
        print('Preparing to train model: {}'.format(args.model_name))

    global device
    device = torch.device(
        'cuda' if torch.cuda.is_available() and not args.cpu else 'cpu')

    sc_will_happen = args.self_critical_from_epoch != -1

    if args.validate is None and args.lr_scheduler == 'ReduceLROnPlateau':
        print(
            'ERROR: you need to enable validation in order to use default lr_scheduler (ReduceLROnPlateau)'
        )
        print('Hint: use something like --validate=coco:val2017')
        sys.exit(1)

    # Create model directory
    if not os.path.exists(args.model_path):
        os.makedirs(args.model_path)

    # Image preprocessing, normalization for the pretrained resnet
    transform = transforms.Compose([
        # transforms.Resize((256, 256)),
        transforms.RandomCrop(args.crop_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    scorers = {}
    if args.validation_scoring is not None or sc_will_happen:
        assert not (
            args.validation_scoring is None and sc_will_happen
        ), "Please provide a metric when using self-critical training"
        for s in args.validation_scoring.split(','):
            s = s.lower().strip()
            if s == 'cider':
                from eval.cider import Cider
                scorers['CIDEr'] = Cider()
            if s == 'ciderd':
                from eval.ciderD.ciderD import CiderD
                scorers['CIDEr-D'] = CiderD(df=args.cached_words)

    ########################
    # Set Model parameters #
    ########################

    # Store parameters gotten from arguments separately:
    arg_params = ModelParams.fromargs(args)

    print("Model parameters inferred from command arguments: ")
    print(arg_params)
    start_epoch = 0

    ###############################
    # Load existing model state   #
    # and update Model parameters #
    ###############################

    state = None

    if args.load_model:
        try:
            state = torch.load(args.load_model, map_location=device)
        except AttributeError:
            print(
                'WARNING: Old model found. Please use model_update.py in the model before executing this script.'
            )
            exit(1)
        new_external_features = arg_params.features.external

        params = ModelParams(state, arg_params=arg_params)
        if len(new_external_features) and \
           params.features.external != new_external_features:
            print('WARNING: external features changed: ',
                  params.features.external, new_external_features)
            print('Updating feature paths...')
            params.update_ext_features(new_external_features)
        start_epoch = state['epoch']
        print('Loaded model {} at epoch {}'.format(args.load_model,
                                                   start_epoch))
    else:
        params = arg_params
        params.command_history = []

    if params.rnn_hidden_init == 'from_features' and params.skip_start_token:
        print(
            "ERROR: Please remove --skip_start_token if you want to use image features "
            "to initialize hidden and cell states. <start> token is needed to trigger "
            "the process of sequence generation, since we don't have image features "
            "embedding as the first input token.")
        sys.exit(1)

    # Force set the following hierarchical model parameters every time:
    if arg_params.hierarchical_model:
        params.hierarchical_model = True
        params.max_sentences = arg_params.max_sentences
        params.weight_sentence_loss = arg_params.weight_sentence_loss
        params.weight_word_loss = arg_params.weight_word_loss
        params.dropout_stopping = arg_params.dropout_stopping
        params.dropout_fc = arg_params.dropout_fc
        params.coherent_sentences = arg_params.coherent_sentences
        params.coupling_alpha = arg_params.coupling_alpha
        params.coupling_beta = arg_params.coupling_beta

    assert args.replace or \
        not os.path.isdir(os.path.join(args.output_root, args.model_path, get_model_name(args, params))) or \
        not (args.load_model and not args.validate_only), \
        '{} already exists. If you want to replace it or resume training please use --replace flag. ' \
        'If you want to validate a loaded model without training it, use --validate_only flag. ' \
        'Otherwise specify a different model name using --model_name flag.' \
        .format(os.path.join(args.output_root, args.model_path, get_model_name(args, params)))

    if args.load_model:
        print("Final model parameters (loaded model + command arguments): ")
        print(params)

    ##############################
    # Load dataset configuration #
    ##############################

    dataset_configs = DatasetParams(args.dataset_config_file)

    if args.dataset is None and not args.validate_only:
        print('ERROR: No dataset selected!')
        print(
            'Please supply a training dataset with the argument --dataset DATASET'
        )
        print('The following datasets are configured in {}:'.format(
            args.dataset_config_file))
        for ds, _ in dataset_configs.config.items():
            if ds not in ('DEFAULT', 'generic'):
                print(' ', ds)
        sys.exit(1)

    if args.validate_only:
        if args.load_model is None:
            print(
                'ERROR: for --validate_only you need to specify a model to evaluate using --load_model MODEL'
            )
            sys.exit(1)
    else:
        dataset_params = dataset_configs.get_params(args.dataset)

        for i in dataset_params:
            i.config_dict['no_tokenize'] = args.no_tokenize
            i.config_dict['show_tokens'] = args.show_tokens
            i.config_dict['skip_start_token'] = params.skip_start_token

            if params.hierarchical_model:
                i.config_dict['hierarchical_model'] = True
                i.config_dict['max_sentences'] = params.max_sentences
                i.config_dict['crop_regions'] = False

    if args.validate is not None:
        validation_dataset_params = dataset_configs.get_params(args.validate)
        for i in validation_dataset_params:
            i.config_dict['no_tokenize'] = args.no_tokenize
            i.config_dict['show_tokens'] = args.show_tokens
            i.config_dict['skip_start_token'] = params.skip_start_token

            if params.hierarchical_model:
                i.config_dict['hierarchical_model'] = True
                i.config_dict['max_sentences'] = params.max_sentences
                i.config_dict['crop_regions'] = False

    #######################
    # Load the vocabulary #
    #######################

    # For pre-trained models attempt to obtain
    # saved vocabulary from the model itself:
    if args.load_model and params.vocab is not None:
        print("Loading vocabulary from the model file:")
        vocab = params.vocab
    else:
        if args.vocab is None:
            print(
                "ERROR: You must specify the vocabulary to be used for training using "
                "--vocab flag.\nTry --vocab AUTO if you want the vocabulary to be "
                "either generated from the training dataset or loaded from cache."
            )
            sys.exit(1)
        print("Loading / generating vocabulary:")
        vocab = get_vocab(args, dataset_params)

    print('Size of the vocabulary is {}'.format(len(vocab)))

    ##########################
    # Initialize data loader #
    ##########################

    ext_feature_sets = [
        params.features.external, params.persist_features.external
    ]
    if not args.validate_only:
        print('Loading dataset: {} with {} workers'.format(
            args.dataset, args.num_workers))
        if params.skip_start_token:
            print("Skipping the use of <start> token...")
        data_loader, ef_dims = get_loader(
            dataset_params,
            vocab,
            transform,
            args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            ext_feature_sets=ext_feature_sets,
            skip_images=not params.has_internal_features(),
            verbose=args.verbose,
            unique_ids=sc_will_happen)
        if sc_will_happen:
            gts_sc = get_ground_truth_captions(data_loader.dataset)

    gts_sc_valid = None
    if args.validate is not None:
        valid_loader, ef_dims = get_loader(
            validation_dataset_params,
            vocab,
            transform,
            args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            ext_feature_sets=ext_feature_sets,
            skip_images=not params.has_internal_features(),
            verbose=args.verbose)
        gts_sc_valid = get_ground_truth_captions(
            valid_loader.dataset) if sc_will_happen else None

    #########################################
    # Setup (optional) TensorBoardX logging #
    #########################################

    writer = None
    if args.tensorboard:
        if SummaryWriter is not None:
            model_name = get_model_name(args, params)
            timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
            log_dir = os.path.join(
                args.output_root, 'log_tb/{}_{}'.format(model_name, timestamp))
            writer = SummaryWriter(log_dir=log_dir)
            print("INFO: Logging TensorBoardX events to {}".format(log_dir))
        else:
            print(
                "WARNING: SummaryWriter object not available. "
                "Hint: Please install TensorBoardX using pip install tensorboardx"
            )

    ######################
    # Build the model(s) #
    ######################

    # Set per parameter learning rate here, if supplied by the user:

    if args.lr_word_decoder is not None:
        if not params.hierarchical_model:
            print(
                "ERROR: Setting word decoder learning rate currently supported in Hierarchical Model only."
            )
            sys.exit(1)

        lr_dict = {'word_decoder': args.lr_word_decoder}
    else:
        lr_dict = {}

    model = EncoderDecoder(params,
                           device,
                           len(vocab),
                           state,
                           ef_dims,
                           lr_dict=lr_dict)

    ######################
    # Optimizer and loss #
    ######################

    sc_activated = False
    opt_params = model.get_opt_params()

    # Loss and optimizer
    if params.hierarchical_model:
        criterion = HierarchicalXEntropyLoss(
            weight_sentence_loss=params.weight_sentence_loss,
            weight_word_loss=params.weight_word_loss)
    elif args.share_embedding_weights:
        criterion = SharedEmbeddingXentropyLoss(param_lambda=0.15)
    else:
        criterion = nn.CrossEntropyLoss()

    if sc_will_happen:  # save it for later
        if args.self_critical_loss == 'sc':
            from model.loss import SelfCriticalLoss
            rl_criterion = SelfCriticalLoss()
        elif args.self_critical_loss == 'sc_with_diversity':
            from model.loss import SelfCriticalWithDiversityLoss
            rl_criterion = SelfCriticalWithDiversityLoss()
        elif args.self_critical_loss == 'sc_with_relative_diversity':
            from model.loss import SelfCriticalWithRelativeDiversityLoss
            rl_criterion = SelfCriticalWithRelativeDiversityLoss()
        elif args.self_critical_loss == 'sc_with_bleu_diversity':
            from model.loss import SelfCriticalWithBLEUDiversityLoss
            rl_criterion = SelfCriticalWithBLEUDiversityLoss()
        elif args.self_critical_loss == 'sc_with_repetition':
            from model.loss import SelfCriticalWithRepetitionLoss
            rl_criterion = SelfCriticalWithRepetitionLoss()
        elif args.self_critical_loss == 'mixed':
            from model.loss import MixedLoss
            rl_criterion = MixedLoss()
        elif args.self_critical_loss == 'mixed_with_face':
            from model.loss import MixedWithFACELoss
            rl_criterion = MixedWithFACELoss(vocab_size=len(vocab))
        elif args.self_critical_loss in [
                'sc_with_penalty', 'sc_with_penalty_throughout',
                'sc_masked_tokens'
        ]:
            raise ValueError('Deprecated loss, use \'sc\' loss')
        else:
            raise ValueError('Invalid self-critical loss')

        print('Selected self-critical loss is', rl_criterion)

        if start_epoch >= args.self_critical_from_epoch:
            criterion = rl_criterion
            sc_activated = True
            print('Self-critical loss training begins')

    # When using CyclicalLR, default learning rate should be always 1.0
    if args.lr_scheduler == 'CyclicalLR':
        default_lr = 1.
    else:
        default_lr = 0.001

    if sc_activated:
        optimizer = torch.optim.Adam(
            opt_params,
            lr=args.learning_rate if args.learning_rate else 5e-5,
            weight_decay=args.weight_decay)
    elif args.optimizer == 'adam':
        optimizer = torch.optim.Adam(opt_params,
                                     lr=default_lr,
                                     weight_decay=args.weight_decay)
    elif args.optimizer == 'rmsprop':
        optimizer = torch.optim.RMSprop(opt_params,
                                        lr=default_lr,
                                        weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = torch.optim.SGD(opt_params,
                                    lr=default_lr,
                                    weight_decay=args.weight_decay)
    else:
        print('ERROR: unknown optimizer:', args.optimizer)
        sys.exit(1)

    # We don't want to initialize the optimizer if we are transferring
    # the language model from the regular model to the hierarchical model
    transfer_language_model = False

    if arg_params.hierarchical_model and state and not state.get(
            'hierarchical_model'):
        transfer_language_model = True

    # Set optimizer state to the one found in a loaded model, unless
    # we are doing a transfer learning step from flat to hierarchical model,
    # or we are using self-critical loss,
    # or the number of unique parameter groups has changed, or the user
    # has explicitly told us *not to* reuse optimizer parameters from before
    if state and not transfer_language_model and not sc_activated and not args.optimizer_reset:
        # Check that number of parameter groups is the same
        if len(optimizer.param_groups) == len(
                state['optimizer']['param_groups']):
            optimizer.load_state_dict(state['optimizer'])

    # Override learning rate if set explicitly in arguments:
    # 1) Global learning rate:
    if args.learning_rate:
        for param_group in optimizer.param_groups:
            param_group['lr'] = args.learning_rate
        params.learning_rate = args.learning_rate
    else:
        params.learning_rate = default_lr

    # 2) Parameter-group specific learning rate:
    if args.lr_word_decoder is not None:
        # We want to give user an option to set learning rate for word_decoder
        # separately. Other exceptions can be added as needed:
        for param_group in optimizer.param_groups:
            if param_group.get('name') == 'word_decoder':
                param_group['lr'] = args.lr_word_decoder
                break

    if args.validate is not None and args.lr_scheduler == 'ReduceLROnPlateau':
        print('Using ReduceLROnPlateau learning rate scheduler')
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               'min',
                                                               verbose=True,
                                                               patience=2)
    elif args.lr_scheduler == 'StepLR':
        print('Using StepLR learning rate scheduler with step_size {}'.format(
            args.lr_step_size))
        # Decrease the learning rate by the factor of gamma at every
        # step_size epochs (for example every 5 or 10 epochs):
        step_size = args.lr_step_size
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                    step_size,
                                                    gamma=0.5,
                                                    last_epoch=-1)
    elif args.lr_scheduler == 'CyclicalLR':
        print(
            "Using Cyclical learning rate scheduler, lr range: [{},{}]".format(
                args.lr_cyclical_min, args.lr_cyclical_max))

        step_size = len(data_loader)
        clr = cyclical_lr(step_size,
                          min_lr=args.lr_cyclical_min,
                          max_lr=args.lr_cyclical_max)
        n_groups = len(optimizer.param_groups)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                      [clr] * n_groups)
    elif args.lr_scheduler is not None:
        print('ERROR: Invalid learning rate scheduler specified: {}'.format(
            args.lr_scheduler))
        sys.exit(1)

    ###################
    # Train the model #
    ###################

    stats_postfix = None
    if args.validate_only:
        stats_postfix = args.validate
    if args.load_model:
        all_stats = init_stats(args, params, postfix=stats_postfix)
    else:
        all_stats = {}

    if args.force_epoch:
        start_epoch = args.force_epoch - 1

    if not args.validate_only:
        total_step = len(data_loader)
        print(
            'Start training with start_epoch={:d} num_epochs={:d} num_batches={:d} ...'
            .format(start_epoch, args.num_epochs, args.num_batches))

    if args.teacher_forcing != 'always':
        print('\t k: {}'.format(args.teacher_forcing_k))
        print('\t beta: {}'.format(args.teacher_forcing_beta))
    print('Optimizer:', optimizer)

    if args.validate_only:
        stats = {}
        teacher_p = 1.0
        if args.teacher_forcing != 'always':
            print(
                'WARNING: teacher_forcing!=always, not yet implemented for --validate_only mode'
            )

        epoch = start_epoch - 1
        if str(epoch + 1) in all_stats and args.skip_existing_validations:
            print('WARNING: epoch {} already validated, skipping...'.format(
                epoch + 1))
            return

        val_loss = do_validate(model, valid_loader, criterion, scorers, vocab,
                               teacher_p, args, params, stats, epoch,
                               sc_activated, gts_sc_valid)
        all_stats[str(epoch + 1)] = stats
        save_stats(args, params, all_stats, postfix=stats_postfix)
    else:
        for epoch in range(start_epoch, args.num_epochs):
            stats = {}
            begin = datetime.now()

            total_loss = 0

            if params.hierarchical_model:
                total_loss_sent = 0
                total_loss_word = 0

            num_batches = 0
            vocab_counts = {
                'cnt': 0,
                'max': 0,
                'min': 9999,
                'sum': 0,
                'unk_cnt': 0,
                'unk_sum': 0
            }

            # Switch to self-critical training once the scheduled epoch is reached:
            if not sc_activated and sc_will_happen and epoch >= args.self_critical_from_epoch:
                if all_stats:
                    best_ep, best_cider = max(
                        [(ep, all_stats[ep]['validation_cider'])
                         for ep in all_stats],
                        key=lambda x: x[1])
                    print('Loading model from epoch', best_ep,
                          'which has the best validation CIDEr:', best_cider)
                    state = torch.load(
                        get_model_path(args, params, int(best_ep)))
                    model = EncoderDecoder(params,
                                           device,
                                           len(vocab),
                                           state,
                                           ef_dims,
                                           lr_dict=lr_dict)
                    opt_params = model.get_opt_params()

                optimizer = torch.optim.Adam(opt_params,
                                             lr=5e-5,
                                             weight_decay=args.weight_decay)
                criterion = rl_criterion
                print('Self-critical loss training begins')
                sc_activated = True

            for i, data in enumerate(data_loader):

                if params.hierarchical_model:
                    (images, captions, lengths, image_ids, features,
                     sorting_order, last_sentence_indicator) = data
                    sorting_order = sorting_order.to(device)
                else:
                    (images, captions, lengths, image_ids, features) = data

                if epoch == 0:
                    unk = vocab('<unk>')
                    for j in range(captions.shape[0]):
                        # Flatten the caption in case it's a paragraph
                        # this is harmless for regular captions too:
                        xl = captions[j, :].view(-1)
                        xw = xl > unk
                        xu = xl == unk
                        xwi = sum(xw).item()
                        xui = sum(xu).item()
                        vocab_counts['cnt'] += 1
                        vocab_counts['sum'] += xwi
                        vocab_counts['max'] = max(vocab_counts['max'], xwi)
                        vocab_counts['min'] = min(vocab_counts['min'], xwi)
                        vocab_counts['unk_cnt'] += xui > 0
                        vocab_counts['unk_sum'] += xui
                # Set mini-batch dataset
                images = images.to(device)
                captions = captions.to(device)

                # Remove <start> token from targets if we are initializing the RNN
                # hidden state from image features:
                if params.rnn_hidden_init == 'from_features' and not params.hierarchical_model:
                    # Subtract one from all lengths to match new target lengths:
                    lengths = [x - 1 if x > 0 else x for x in lengths]
                    targets = pack_padded_sequence(captions[:, 1:],
                                                   lengths,
                                                   batch_first=True)[0]
                else:
                    if params.hierarchical_model:
                        targets = prepare_hierarchical_targets(
                            last_sentence_indicator, args.max_sentences,
                            lengths, captions, device)
                    else:
                        targets = pack_padded_sequence(captions,
                                                       lengths,
                                                       batch_first=True)[0]
                        sorting_order = None

                init_features = features[0].to(device) if len(
                    features) > 0 and features[0] is not None else None
                persist_features = features[1].to(device) if len(
                    features) > 1 and features[1] is not None else None

                # Forward, backward and optimize
                # Calculate the probability whether to use teacher forcing or not:

                # Iterate over batches:
                iteration = (epoch - start_epoch) * len(data_loader) + i

                teacher_p = get_teacher_prob(args.teacher_forcing_k, iteration,
                                             args.teacher_forcing_beta)

                # Allow model to log values at the last batch of the epoch
                writer_data = None
                if writer and (i == len(data_loader) - 1
                               or i == args.num_batches - 1):
                    writer_data = {'writer': writer, 'epoch': epoch + 1}

                sample_len = captions.size(1) if args.self_critical_loss in [
                    'mixed', 'mixed_with_face'
                ] else 20
                if sc_activated:
                    sampled_seq, sampled_log_probs, outputs = model.sample(
                        images,
                        init_features,
                        persist_features,
                        max_seq_length=sample_len,
                        start_token_id=vocab('<start>'),
                        trigram_penalty_alpha=args.trigram_penalty_alpha,
                        stochastic_sampling=True,
                        output_logprobs=True,
                        output_outputs=True)
                    sampled_seq = model.decoder.alt_prob_to_tensor(
                        sampled_seq, device=device)
                else:
                    outputs = model(images,
                                    init_features,
                                    captions,
                                    lengths,
                                    persist_features,
                                    teacher_p,
                                    args.teacher_forcing,
                                    sorting_order,
                                    writer_data=writer_data)

                if args.share_embedding_weights:
                    # Weights of (HxH) projection matrix used for regularizing
                    # models that share embedding weights
                    projection = model.decoder.projection.weight
                    loss = criterion(projection, outputs, targets)
                elif sc_activated:
                    # get greedy decoding baseline
                    model.eval()
                    with torch.no_grad():
                        greedy_sampled_seq = model.sample(
                            images,
                            init_features,
                            persist_features,
                            max_seq_length=sample_len,
                            start_token_id=vocab('<start>'),
                            trigram_penalty_alpha=args.trigram_penalty_alpha,
                            stochastic_sampling=False)
                        greedy_sampled_seq = model.decoder.alt_prob_to_tensor(
                            greedy_sampled_seq, device=device)
                    model.train()

                    if args.self_critical_loss in [
                            'sc', 'sc_with_diversity',
                            'sc_with_relative_diversity',
                            'sc_with_bleu_diversity', 'sc_with_repetition'
                    ]:
                        loss, advantage = criterion(
                            sampled_seq,
                            sampled_log_probs,
                            greedy_sampled_seq, [gts_sc[i] for i in image_ids],
                            scorers,
                            vocab,
                            return_advantage=True)
                    elif args.self_critical_loss in ['mixed']:
                        loss, advantage = criterion(
                            sampled_seq,
                            sampled_log_probs,
                            outputs,
                            greedy_sampled_seq, [gts_sc[i] for i in image_ids],
                            scorers,
                            vocab,
                            targets,
                            lengths,
                            gamma_ml_rl=args.gamma_ml_rl,
                            return_advantage=True)
                    elif args.self_critical_loss in ['mixed_with_face']:
                        loss, advantage = criterion(
                            sampled_seq,
                            sampled_log_probs,
                            outputs,
                            greedy_sampled_seq, [gts_sc[i] for i in image_ids],
                            scorers,
                            vocab,
                            captions,
                            targets,
                            lengths,
                            gamma_ml_rl=args.gamma_ml_rl,
                            return_advantage=True)
                    else:
                        raise ValueError('Invalid self-critical loss')

                    if writer is not None and i % 100 == 0:
                        writer.add_scalar('training_loss', loss.item(),
                                          epoch * len(data_loader) + i)
                        writer.add_scalar('advantage', advantage,
                                          epoch * len(data_loader) + i)
                        writer.add_scalar('lr',
                                          optimizer.param_groups[0]['lr'],
                                          epoch * len(data_loader) + i)
                else:
                    loss = criterion(outputs, targets)

                model.zero_grad()
                loss.backward()

                # Clip gradients if desired:
                if args.grad_clip is not None:
                    # grad_norms = [x.grad.data.norm(2) for x in opt_params]
                    # batch_max_grad = np.max(grad_norms)
                    # if batch_max_grad > 10.0:
                    #     print('WARNING: gradient norms larger than 10.0')

                    # torch.nn.utils.clip_grad_norm_(decoder.parameters(), 0.1)
                    # torch.nn.utils.clip_grad_norm_(encoder.parameters(), 0.1)
                    clip_gradients(optimizer, args.grad_clip)

                # Update weights:
                optimizer.step()

                # CyclicalLR requires us to update LR at every minibatch:
                if args.lr_scheduler == 'CyclicalLR':
                    scheduler.step()

                total_loss += loss.item()

                num_batches += 1

                if params.hierarchical_model:
                    _, loss_sent, _, loss_word = criterion.item_terms()
                    total_loss_sent += float(loss_sent)
                    total_loss_word += float(loss_word)

                # Print log info
                if (i + 1) % args.log_step == 0:
                    print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}, '
                          'Perplexity: {:5.4f}'.format(epoch + 1,
                                                       args.num_epochs, i + 1,
                                                       total_step, loss.item(),
                                                       np.exp(loss.item())))
                    sys.stdout.flush()

                    if params.hierarchical_model:
                        weight_sent, loss_sent, weight_word, loss_word = \
                            criterion.item_terms()
                        print('Sentence Loss: {:.4f}, '
                              'Word Loss: {:.4f}'.format(
                                  float(loss_sent), float(loss_word)))
                        sys.stdout.flush()

                if i + 1 == args.num_batches:
                    break

            end = datetime.now()

            stats['training_loss'] = total_loss / num_batches

            if params.hierarchical_model:
                stats['loss_sentence'] = total_loss_sent / num_batches
                stats['loss_word'] = total_loss_word / num_batches

            print('Epoch {} duration: {}, average loss: {:.4f}'.format(
                epoch + 1, end - begin, stats['training_loss']))

            save_model(args, params, model.encoder, model.decoder, optimizer,
                       epoch, vocab)

            if epoch == 0:
                vocab_counts['avg'] = vocab_counts['sum'] / vocab_counts['cnt']
                vocab_counts['unk_cnt_per'] = \
                    100 * vocab_counts['unk_cnt'] / vocab_counts['cnt']
                vocab_counts['unk_sum_per'] = \
                    100 * vocab_counts['unk_sum'] / vocab_counts['sum']
                # print(vocab_counts)
                print((
                    'Training data contains {sum} words in {cnt} captions (avg. {avg:.1f} w/c)'
                    + ' with {unk_sum} <unk>s ({unk_sum_per:.1f}%)' +
                    ' in {unk_cnt} ({unk_cnt_per:.1f}%) captions').format(
                        **vocab_counts))

            ############################################
            # Validation loss and learning rate update #
            ############################################

            if args.validate is not None and (epoch +
                                              1) % args.validation_step == 0:
                val_loss = do_validate(model, valid_loader, criterion, scorers,
                                       vocab, teacher_p, args, params, stats,
                                       epoch, sc_activated, gts_sc_valid)

                if args.lr_scheduler == 'ReduceLROnPlateau':
                    scheduler.step(val_loss)
            elif args.lr_scheduler == 'StepLR':
                scheduler.step()

            all_stats[str(epoch + 1)] = stats
            save_stats(args, params, all_stats, writer=writer)

            if writer is not None:
                # Log model data to tensorboard
                log_model_data(params, model, epoch + 1, writer)

    if writer is not None:
        writer.close()
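For reference, once `sc_activated` is set the training objective follows the standard self-critical (SCST) form: the reward of a greedily decoded baseline caption is subtracted from the reward of a sampled caption, and the resulting advantage weights the sampled log-probabilities. A minimal sketch of that loss, assuming per-image rewards are already computed; the `model.loss` classes imported above wrap this pattern together with masking and the various diversity terms:

import torch

def self_critical_loss(sample_logprobs, sample_reward, greedy_reward):
    """sample_logprobs: (batch, seq_len) log-probs of the sampled tokens;
    rewards: (batch,) caption scores (e.g. CIDEr) for sampled/greedy output."""
    advantage = (sample_reward - greedy_reward).unsqueeze(1)  # (batch, 1)
    # Maximize advantage-weighted log-likelihood = minimize its negative:
    return -(advantage * sample_logprobs).mean()

# Toy usage with made-up values:
logprobs = torch.log(torch.rand(2, 5))
loss = self_critical_loss(logprobs,
                          torch.tensor([0.8, 0.3]),
                          torch.tensor([0.5, 0.4]))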