def validate(args, device_id, pt, step):
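    """Load the checkpoint `pt` (falling back to `args.test_from`), rebuild the
    model with the flags saved alongside it, and return the validation
    cross-entropy at training `step`."""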
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    test_from = pt if pt != '' else args.test_from
    logger.info('Loading checkpoint from %s' % test_from)
    checkpoint = torch.load(test_from, map_location=lambda storage, loc: storage)
    # Restore the model-architecture flags saved with the checkpoint so the
    # rebuilt model matches the one that was trained.
    opt = vars(checkpoint['opt'])
    for k in opt:
        if k in model_flags:
            setattr(args, k, opt[k])
    logger.info(str(args))

    model = AbsSummarizer(args, device, checkpoint)
    model.eval()

    valid_iter = data_loader.Dataloader(args, load_dataset(args, 'valid', shuffle=False),
                                        args.batch_size, device,
                                        shuffle=False, is_test=False)

    tokenizer = BertData(args).tokenizer
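    # Map the decoder's special symbols onto ids BertData reserves in the BERT
    # vocabulary: [unused0]/[unused1] open and close a summary, [unused2] (EOQ)
    # separates target sentences.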
    symbols = {'BOS': tokenizer.convert_tokens_to_ids('[unused0]'), 'EOS': tokenizer.convert_tokens_to_ids('[unused1]'),
               'PAD': tokenizer.convert_tokens_to_ids('[PAD]'), 'EOQ': tokenizer.convert_tokens_to_ids('[unused2]')}

    valid_loss = abs_loss(model.generator, symbols, model.vocab_size, train=False, device=device)

    trainer = build_trainer(args, device_id, model, None, valid_loss)
    stats = trainer.validate(valid_iter, step)
    return stats.xent()


def train_abs_single(args, device_id):
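    """Train the abstractive summarizer on a single device. Optionally resumes
    from `args.train_from` and/or initializes the BERT encoder from an
    extractive checkpoint (`args.load_from_extractive`)."""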
    init_logger(args.log_file)
    logger.info(str(args))
    device = "cpu" if args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d' % device_id)
    logger.info('Device %s' % device)
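    # Seed every RNG and force deterministic cuDNN kernels for reproducibility.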
    torch.manual_seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    if device_id >= 0:
        torch.cuda.set_device(device_id)
        torch.cuda.manual_seed(args.seed)

    if args.train_from != '':
        logger.info('Loading checkpoint from %s' % args.train_from)
        checkpoint = torch.load(args.train_from,
                                map_location=lambda storage, loc: storage)
        # Restore the model-architecture flags saved with the checkpoint.
        opt = vars(checkpoint['opt'])
        for k in opt:
            if k in model_flags:
                setattr(args, k, opt[k])
    else:
        checkpoint = None

    if args.load_from_extractive != '':
        logger.info('Loading bert from extractive model %s' % args.load_from_extractive)
        bert_from_extractive = torch.load(args.load_from_extractive, map_location=lambda storage, loc: storage)
        bert_from_extractive = bert_from_extractive['model']
    else:
        bert_from_extractive = None

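    # Factory the trainer calls whenever it needs a fresh, re-shuffled pass
    # over the training set.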
    def train_iter_fct():
        return data_loader.Dataloader(args, load_dataset(args, 'train', shuffle=True), args.batch_size, device,
                                      shuffle=True, is_test=False)

    model = AbsSummarizer(args, device, checkpoint, bert_from_extractive)
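    # With args.sep_optim, the pretrained BERT encoder and the freshly
    # initialized decoder get separate optimizers (typically with different
    # learning rates and warmup schedules).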
    if args.sep_optim:
        optim_bert = model_builder.build_optim_bert(args, model, checkpoint)
        optim_dec = model_builder.build_optim_dec(args, model, checkpoint)
        optim = [optim_bert, optim_dec]
    else:
        optim = [model_builder.build_optim(args, model, checkpoint)]

    logger.info(model)

    tokenizer = BertData(args).tokenizer

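    # Same special-symbol mapping as in validate() above.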
    symbols = {'BOS': tokenizer.convert_tokens_to_ids('[unused0]'), 'EOS': tokenizer.convert_tokens_to_ids('[unused1]'),
               'PAD': tokenizer.convert_tokens_to_ids('[PAD]'), 'EOQ': tokenizer.convert_tokens_to_ids('[unused2]')}

    train_loss = abs_loss(model.generator, symbols, model.vocab_size, device, train=True,
                          label_smoothing=args.label_smoothing)

    trainer = build_trainer(args, device_id, model, optim, train_loss)

    trainer.train(train_iter_fct, args.train_steps)