def main(args):
    """Run SomDST inference on the test set, caching the prepared examples.

    The prepared test examples are cached as a pickle at a hard-coded path;
    when that file exists it is loaded instead of re-running
    ``prepare_dataset_eval``.

    Args:
        args: Parsed options. This function reads ``data_root``,
            ``ontology_data``, ``test_data_path``, ``n_history``,
            ``max_seq_length``, ``op_code``, ``bert_config_path`` and
            ``model_ckpt_path``.

    NOTE(review): ``device`` and ``domain2id`` are not defined inside this
    function; they are presumably module-level names -- confirm.
    """
    ontology_path = os.path.join(args.data_root, args.ontology_data)
    # Context manager so the ontology file handle is closed deterministically.
    with open(ontology_path) as f:
        ontology = json.load(f)
    slot_meta, _ = make_slot_meta(ontology)
    tokenizer = BertTokenizer.from_pretrained("dsksd/bert-ko-small-minimal")

    out_file = '/opt/ml/code/p3-dst-chatting-day/SomDST/pickles/test_data_raw.pkl'
    if os.path.exists(out_file):
        print("Pickles are exist!")
        # NOTE: pickle.load executes arbitrary code from the file; this is only
        # safe because the cache is expected to be written by this same script.
        with open(out_file, 'rb') as f:
            test_data_raw = pickle.load(f)
        print("Pickles brought!")
    else:
        print("Pickles are not exist!")
        test_data_raw = prepare_dataset_eval(data_path=args.test_data_path,
                                             tokenizer=tokenizer,
                                             slot_meta=slot_meta,
                                             n_history=args.n_history,
                                             max_seq_length=args.max_seq_length,
                                             op_code=args.op_code)

        # Create the cache directory first so the freshly prepared data is not
        # lost to a FileNotFoundError when the directory does not exist yet.
        os.makedirs(os.path.dirname(out_file), exist_ok=True)
        with open(out_file, 'wb') as f:
            pickle.dump(test_data_raw, f)
        print("Pickles saved!")

    print("# test examples %d" % len(test_data_raw))

    model_config = BertConfig.from_json_file(args.bert_config_path)
    model_config.dropout = 0.1
    op2id = OP_SET[args.op_code]
    model = SomDST(model_config, len(op2id), len(domain2id), op2id['update'])
    # Load the weights on CPU first; the model is moved to `device` below.
    ckpt = torch.load(args.model_ckpt_path, map_location='cpu')
    model.load_state_dict(ckpt)

    model.eval()
    model.to(device)
    print("Model is loaded")

    inference_model(model, test_data_raw, tokenizer, slot_meta, args.op_code)
示例#2
0
def main(args):
    """Evaluate a trained TransformerDST checkpoint on the test data.

    Args:
        args: Parsed options. This function reads ``data_root``,
            ``ontology_data``, ``bert_config``, ``test_data``, ``n_history``,
            ``max_seq_length``, ``op_code``, ``bert_config_path``,
            ``model_ckpt_path``, ``eval_all`` and, when ``eval_all`` is falsy,
            ``gt_op``, ``gt_p_state`` and ``gt_gen``.

    NOTE(review): ``domain2id`` is not defined inside this function; it is
    presumably a module-level name -- confirm.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ontology_path = os.path.join(args.data_root, args.ontology_data)
    # Context manager so the ontology file handle is closed deterministically.
    with open(ontology_path) as f:
        ontology = json.load(f)
    slot_meta, _ = make_slot_meta(ontology)

    tokenizer = BertTokenizer.from_pretrained(args.bert_config)
    special_tokens = ['[SLOT]', '[NULL]']
    special_tokens_dict = {'additional_special_tokens': special_tokens}
    tokenizer.add_special_tokens(special_tokens_dict)

    data = prepare_dataset(data_path=os.path.join(args.data_root,
                                                  args.test_data),
                           data_list=None,
                           tokenizer=tokenizer,
                           slot_meta=slot_meta,
                           n_history=args.n_history,
                           max_seq_length=args.max_seq_length,
                           op_code=args.op_code)

    model_config = BertConfig.from_json_file(args.bert_config_path)
    model_config.dropout = 0.1
    op2id = OP_SET[args.op_code]
    model = TransformerDST(model_config, len(op2id), len(domain2id),
                           op2id['update'])
    # Load the weights on CPU first; the model is moved to `device` below.
    ckpt = torch.load(args.model_ckpt_path, map_location='cpu')
    model.load_state_dict(ckpt)

    model.eval()
    model.to(device)

    if args.eval_all:
        # (gt_op, gt_p_state, gt_gen) settings, listed in the exact order the
        # evaluations have always been run (note: not plain binary counting).
        gt_flag_settings = (
            (False, False, False),
            (False, False, True),
            (False, True, False),
            (False, True, True),
            (True, False, False),
            (True, True, False),
            (True, False, True),
            (True, True, True),
        )
    else:
        gt_flag_settings = ((args.gt_op, args.gt_p_state, args.gt_gen),)

    for gt_op, gt_p_state, gt_gen in gt_flag_settings:
        model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
                         gt_op, gt_p_state, gt_gen)
示例#3
0
def main(args):
    """Evaluate a trained TransformerDST checkpoint on the test data.

    Args:
        args: Parsed options. This function reads ``data_root``,
            ``ontology_data``, ``vocab_path``, ``test_data``, ``n_history``,
            ``max_seq_length``, ``op_code``, ``bert_config_path``,
            ``model_ckpt_path``, ``eval_all`` and, when ``eval_all`` is falsy,
            ``gt_op``, ``gt_p_state`` and ``gt_gen``.

    NOTE(review): ``domain2id`` is not defined inside this function; it is
    presumably a module-level name -- confirm.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    ontology_path = os.path.join(args.data_root, args.ontology_data)
    # Context manager so the ontology file handle is closed deterministically.
    with open(ontology_path) as f:
        ontology = json.load(f)
    slot_meta, _ = make_slot_meta(ontology)
    tokenizer = BertTokenizer(args.vocab_path, do_lower_case=True)
    data = prepare_dataset(os.path.join(args.data_root, args.test_data),
                           tokenizer, slot_meta,
                           args.n_history, args.max_seq_length, args.op_code)

    model_config = BertConfig.from_json_file(args.bert_config_path)
    model_config.dropout = 0.1
    op2id = OP_SET[args.op_code]
    model = TransformerDST(model_config, len(op2id), len(domain2id),
                           op2id['update'])
    # Load the weights on CPU first; the model is moved to `device` below.
    ckpt = torch.load(args.model_ckpt_path, map_location='cpu')
    model.load_state_dict(ckpt)

    model.eval()
    model.to(device)

    if args.eval_all:
        # (gt_op, gt_p_state, gt_gen) settings, listed in the exact order the
        # evaluations have always been run (note: not plain binary counting).
        gt_flag_settings = (
            (False, False, False),
            (False, False, True),
            (False, True, False),
            (False, True, True),
            (True, False, False),
            (True, True, False),
            (True, False, True),
            (True, True, True),
        )
    else:
        gt_flag_settings = ((args.gt_op, args.gt_p_state, args.gt_gen),)

    for gt_op, gt_p_state, gt_gen in gt_flag_settings:
        model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
                         gt_op, gt_p_state, gt_gen)
示例#4
0
def main(args):
    """Evaluate a trained SomDST checkpoint on the test data.

    Args:
        args: Parsed options. This function reads ``data_root``,
            ``ontology_data``, ``test_data``, ``n_history``,
            ``max_seq_length``, ``op_code``, ``bert_config_path``,
            ``model_ckpt_path``, ``eval_all`` and, when ``eval_all`` is falsy,
            ``gt_op``, ``gt_p_state`` and ``gt_gen``.

    NOTE(review): ``device`` and ``domain2id`` are not defined inside this
    function; they are presumably module-level names -- confirm.
    """
    ontology_path = os.path.join(args.data_root, args.ontology_data)
    # Context manager so the ontology file handle is closed deterministically.
    with open(ontology_path) as f:
        ontology = json.load(f)
    slot_meta, _ = make_slot_meta(ontology)
    tokenizer = BertTokenizer.from_pretrained("dsksd/bert-ko-small-minimal")
    data = prepare_dataset(os.path.join(args.data_root, args.test_data),
                           tokenizer, slot_meta,
                           args.n_history, args.max_seq_length, args.op_code)

    model_config = BertConfig.from_json_file(args.bert_config_path)
    model_config.dropout = 0.1
    op2id = OP_SET[args.op_code]
    model = SomDST(model_config, len(op2id), len(domain2id), op2id['update'])
    # Load the weights on CPU first; the model is moved to `device` below.
    ckpt = torch.load(args.model_ckpt_path, map_location='cpu')
    model.load_state_dict(ckpt)

    model.eval()
    model.to(device)

    if args.eval_all:
        # (gt_op, gt_p_state, gt_gen) settings, listed in the exact order the
        # evaluations have always been run (note: not plain binary counting).
        gt_flag_settings = (
            (False, False, False),
            (False, False, True),
            (False, True, False),
            (False, True, True),
            (True, False, False),
            (True, True, False),
            (True, False, True),
            (True, True, True),
        )
    else:
        gt_flag_settings = ((args.gt_op, args.gt_p_state, args.gt_gen),)

    for gt_op, gt_p_state, gt_gen in gt_flag_settings:
        model_evaluation(model, data, tokenizer, slot_meta, 0, args.op_code,
                         gt_op, gt_p_state, gt_gen)
示例#5
0
def main(args):
    """Train a TransformerDST model and periodically evaluate it.

    Steps, in order: validate options, pick the device, seed the RNGs, load
    (or build and cache as ``*.pt`` under ``args.data_root``) the test/train/
    dev examples, build the model and load pretrained BERT weights into
    ``model.bert``, build the optimizer(s), then run the training loop. A
    checkpoint is saved via ``save`` after every epoch. Dev evaluation runs
    when ``(epoch + 1) % args.eval_epoch == 0`` and ``epoch + 1 >= 8``; the
    test set is evaluated whenever the dev joint accuracy improves.

    Args:
        args: Parsed options (an argparse-style namespace). The same object is
            also passed to the model as its decoder config. ``args.no_dial``,
            ``args.random_seed`` and ``args.bert_ckpt_path`` may be overwritten
            here.

    Raises:
        AssertionError: If ``args.use_one_optim`` is not ``True``.
        NotImplementedError: If ``args.recover_e > 0``.

    NOTE(review): ``domain2id`` is not defined inside this function; it is
    presumably a module-level name -- confirm.
    """

    assert args.use_one_optim is True

    if args.use_cls_only:
        args.no_dial = True

    print("### use_cls_only: {:}".format(args.use_cls_only))
    print("### no_dial: {:}".format(args.no_dial))

    if args.recover_e > 0:
        raise NotImplementedError("This option is from my oldest code version. "
                                  "I have not checked it for this code version.")

    # makedirs (rather than mkdir) so a nested save_dir is created in one go.
    # The second, identical existence check that used to follow the seeding
    # code was redundant and has been dropped.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
        print("### mkdir {:}".format(args.save_dir))

    def worker_init_fn(worker_id):
        # Give every DataLoader worker its own, reproducible NumPy seed.
        np.random.seed(args.random_seed + worker_id)

    n_gpu = 0
    if torch.cuda.is_available() and (not args.use_cpu):
        n_gpu = torch.cuda.device_count()
        device = torch.device('cuda')
        print("### Device: {:}".format(device))
    else:
        print("### Use CPU (Debugging)")
        device = torch.device("cpu")

    if args.random_seed < 0:
        print("### Pick a random seed")
        args.random_seed = random.sample(list(range(0, 100000)), 1)[0]

    print("### Random Seed: {:}".format(args.random_seed))
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    rng = random.Random(args.random_seed)
    torch.manual_seed(args.random_seed)

    if n_gpu > 0:
        if args.random_seed >= 0:
            torch.cuda.manual_seed(args.random_seed)
            torch.cuda.manual_seed_all(args.random_seed)

        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # Context manager so the ontology file handle is closed deterministically.
    with open(args.ontology_data) as f:
        ontology = json.load(f)
    slot_meta, ontology = make_slot_meta(ontology)
    op2id = OP_SET[args.op_code]
    print(op2id)

    tokenizer = BertTokenizer(args.vocab_path, do_lower_case=True)

    # Prepared examples are cached next to the data; delete the .pt files to
    # force re-preprocessing.
    train_path = os.path.join(args.data_root, "train.pt")
    dev_path = os.path.join(args.data_root, "dev.pt")
    test_path = os.path.join(args.data_root, "test.pt")

    if not os.path.exists(test_path):
        test_data_raw = prepare_dataset(data_path=args.test_data_path,
                                        tokenizer=tokenizer,
                                        slot_meta=slot_meta,
                                        n_history=args.n_history,
                                        max_seq_length=args.max_seq_length,
                                        op_code=args.op_code)
        torch.save(test_data_raw, test_path)
    else:
        test_data_raw = torch.load(test_path)

    print("# test examples %d" % len(test_data_raw))

    if not os.path.exists(train_path):
        train_data_raw = prepare_dataset(data_path=args.train_data_path,
                                         tokenizer=tokenizer,
                                         slot_meta=slot_meta,
                                         n_history=args.n_history,
                                         max_seq_length=args.max_seq_length,
                                         op_code=args.op_code)

        torch.save(train_data_raw, train_path)
    else:
        train_data_raw = torch.load(train_path)

    train_data = MultiWozDataset(train_data_raw,
                                 tokenizer,
                                 slot_meta,
                                 args.max_seq_length,
                                 rng,
                                 ontology,
                                 args.word_dropout,
                                 args.shuffle_state,
                                 args.shuffle_p, pad_id=tokenizer.convert_tokens_to_ids(['[PAD]'])[0],
                                 slot_id=tokenizer.convert_tokens_to_ids(['[SLOT]'])[0],
                                 decoder_teacher_forcing=args.decoder_teacher_forcing,
                                 use_full_slot=args.use_full_slot,
                                 use_dt_only=args.use_dt_only, no_dial=args.no_dial,
                                 use_cls_only=args.use_cls_only)

    print("# train examples %d" % len(train_data_raw))

    if not os.path.exists(dev_path):
        dev_data_raw = prepare_dataset(data_path=args.dev_data_path,
                                       tokenizer=tokenizer,
                                       slot_meta=slot_meta,
                                       n_history=args.n_history,
                                       max_seq_length=args.max_seq_length,
                                       op_code=args.op_code)
        torch.save(dev_data_raw,  dev_path)
    else:
        dev_data_raw = torch.load(dev_path)

    print("# dev examples %d" % len(dev_data_raw))

    model_config = BertConfig.from_json_file(args.bert_config_path)
    model_config.dropout = args.dropout
    model_config.attention_probs_dropout_prob = args.attention_probs_dropout_prob
    model_config.hidden_dropout_prob = args.hidden_dropout_prob

    type_vocab_size = 4
    dec_config = args
    model = TransformerDST(model_config, dec_config, len(op2id), len(domain2id),
                           op2id['update'],
                           tokenizer.convert_tokens_to_ids(['[MASK]'])[0],
                           tokenizer.convert_tokens_to_ids(['[SEP]'])[0],
                           tokenizer.convert_tokens_to_ids(['[PAD]'])[0],
                           tokenizer.convert_tokens_to_ids(['-'])[0],
                           type_vocab_size, args.exclude_domain)

    if not os.path.exists(args.bert_ckpt_path):
        args.bert_ckpt_path = download_ckpt(args.bert_ckpt_path, args.bert_config_path, 'assets')

    state_dict = torch.load(args.bert_ckpt_path, map_location='cpu')
    _k = 'embeddings.token_type_embeddings.weight'
    print("config.type_vocab_size != state_dict[bert.embeddings.token_type_embeddings.weight] ({0} != {1})".format(
            type_vocab_size, state_dict[_k].shape[0]))
    # Grow the pretrained token-type embedding table to `type_vocab_size` rows
    # in place, then fill the new rows 2 and 3 with a copy of row 0 so the
    # state dict matches the model's shape before loading.
    state_dict[_k].resize_(
        type_vocab_size, state_dict[_k].shape[1])
    state_dict[_k].data[2, :].copy_(state_dict[_k].data[0, :])
    state_dict[_k].data[3, :].copy_(state_dict[_k].data[0, :])
    model.bert.load_state_dict(state_dict)
    print("\n### Done Load BERT")
    sys.stdout.flush()

    # re-initialize added special tokens ([SLOT], [NULL], [EOS])
    model.bert.embeddings.word_embeddings.weight.data[1].normal_(mean=0.0, std=0.02)
    model.bert.embeddings.word_embeddings.weight.data[2].normal_(mean=0.0, std=0.02)
    model.bert.embeddings.word_embeddings.weight.data[3].normal_(mean=0.0, std=0.02)

    # re-initialize seg-2, seg-3
    model.bert.embeddings.token_type_embeddings.weight.data[2].normal_(mean=0.0, std=0.02)
    model.bert.embeddings.token_type_embeddings.weight.data[3].normal_(mean=0.0, std=0.02)
    model.to(device)

    num_train_steps = int(len(train_data_raw) / args.batch_size * args.n_epochs)

    if args.use_one_optim:
        print("### Use One Optim")
        param_optimizer = list(model.named_parameters())
        # Parameters whose name contains one of these get no weight decay.
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(
                nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(
                nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.enc_lr)
        scheduler = WarmupLinearSchedule(optimizer, int(num_train_steps * args.enc_warmup),
                                             t_total=num_train_steps)
    else:
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        enc_param_optimizer = list(model.bert.named_parameters())  # TODO: For BERT only
        print('### Optim BERT: {:}'.format(len(enc_param_optimizer)))
        enc_optimizer_grouped_parameters = [
            {'params': [p for n, p in enc_param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in enc_param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
            ]

        enc_optimizer = AdamW(enc_optimizer_grouped_parameters, lr=args.enc_lr)
        enc_scheduler = WarmupLinearSchedule(enc_optimizer, int(num_train_steps * args.enc_warmup),
                                             t_total=num_train_steps)

        dec_param_optimizer = list(model.named_parameters())  # TODO:  For other parameters
        print('### Optim All: {:}'.format(len(dec_param_optimizer)))
        dec_param_optimizer = [p for (n, p) in dec_param_optimizer if 'bert' not in n]
        print('### Optim OTH: {:}'.format(len(dec_param_optimizer)))
        dec_optimizer = AdamW(dec_param_optimizer, lr=args.dec_lr)
        dec_scheduler = WarmupLinearSchedule(dec_optimizer, int(num_train_steps * args.dec_warmup),
                                             t_total=num_train_steps)

    # NOTE: currently unreachable -- recover_e > 0 raises NotImplementedError
    # at the top of this function. Kept for when recovery is re-enabled.
    if args.recover_e > 0:
        model_recover, enc_recover, dec_recover = load(args, str(args.recover_e))
        print("### Recover Model E{:}".format(args.recover_e))
        sys.stdout.flush()
        model.load_state_dict(model_recover)
        print("### Recover Optim E{:}".format(args.recover_e))
        sys.stdout.flush()
        enc_optimizer.load_state_dict(enc_recover)
        # Fixed: this used to pass the optimizer object itself instead of the
        # recovered state.
        dec_optimizer.load_state_dict(dec_recover)

    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.batch_size,
                                  collate_fn=train_data.collate_fn,
                                  num_workers=args.num_workers,
                                  worker_init_fn=worker_init_fn)

    loss_fnc = nn.CrossEntropyLoss()
    best_score = {'epoch': 0, 'joint_acc': 0, 'op_acc': 0, 'final_slot_f1': 0}

    start_time = time.time()

    for epoch in range(args.n_epochs):
        batch_loss = []
        model.train()
        for step, batch in enumerate(train_dataloader):

            # Move batch items to the device; ints, dicts, lists and NumPy
            # arrays are passed through unchanged.
            batch = [b if isinstance(b, (int, dict, list, np.ndarray)) else b.to(device)
                     for b in batch]

            input_ids_p, segment_ids_p, input_mask_p, \
            state_position_ids, op_ids, domain_ids, input_ids_g, segment_ids_g, position_ids_g, input_mask_g, \
            masked_pos, masked_weights, lm_label_ids, id_n_map, gen_max_len, n_total_pred = batch

            domain_scores, state_scores, loss_g = model(input_ids_p, segment_ids_p, input_mask_p, state_position_ids,
                input_ids_g, segment_ids_g, position_ids_g, input_mask_g,
                masked_pos, masked_weights, lm_label_ids, id_n_map, gen_max_len, only_pred_op=args.only_pred_op, n_gpu=n_gpu)

            # Normalise the generation loss by the prediction count of the
            # batch; with nothing to predict the generation loss is zero.
            if n_total_pred > 0:
                loss_g = loss_g.sum() / n_total_pred
            else:
                loss_g = 0

            loss_s = loss_fnc(state_scores.view(-1, len(op2id)), op_ids.view(-1))

            if args.only_pred_op:
                loss = loss_s
            else:
                loss = loss_s + loss_g

            if args.exclude_domain is not True:
                loss_d = loss_fnc(domain_scores.view(-1, len(domain2id)), domain_ids.view(-1))
                loss = loss + loss_d

            batch_loss.append(loss.item())

            loss.backward()

            if args.use_one_optim:
                optimizer.step()
                scheduler.step()
            else:
                enc_optimizer.step()
                enc_scheduler.step()
                dec_optimizer.step()
                dec_scheduler.step()

            model.zero_grad()

            if step % 100 == 0:
                try:
                    loss_g = loss_g.item()
                except AttributeError:
                    # loss_g is the plain int 0 set above; print it as is.
                    pass

                if args.exclude_domain is not True:
                    print("time %.1f min, [%d/%d] [%d/%d] mean_loss : %.3f, state_loss : %.3f, gen_loss : %.3f, dom_loss : %.3f" \
                          % ((time.time()-start_time)/60, epoch+1, args.n_epochs, step,
                             len(train_dataloader), np.mean(batch_loss),
                             loss_s.item(), loss_g, loss_d.item()))
                else:
                    print("time %.1f min, [%d/%d] [%d/%d] mean_loss : %.3f, state_loss : %.3f, gen_loss : %.3f" \
                          % ((time.time()-start_time)/60, epoch+1, args.n_epochs, step,
                             len(train_dataloader), np.mean(batch_loss),
                             loss_s.item(), loss_g))

                sys.stdout.flush()
                # The running mean is reported per 100-step window.
                batch_loss = []

        if args.use_one_optim:
            save(args, epoch + 1, model, optimizer)
        else:
            save(args, epoch + 1, model, enc_optimizer, dec_optimizer)

        if ((epoch+1) % args.eval_epoch == 0) and (epoch+1 >= 8):
            eval_res = model_evaluation(model, dev_data_raw, tokenizer, slot_meta, epoch+1, args.op_code,
                                        use_full_slot=args.use_full_slot, use_dt_only=args.use_dt_only, no_dial=args.no_dial, use_cls_only=args.use_cls_only, n_gpu=n_gpu)
            print("### Epoch {:} Score : ".format(epoch+1), eval_res)

            if eval_res['joint_acc'] > best_score['joint_acc']:
                best_score = eval_res
                print("### Best Joint Acc: {:} ###".format(best_score['joint_acc']))
                print('\n')

                if epoch+1 >= 8:  # To speed up
                    eval_res_test = model_evaluation(model, test_data_raw, tokenizer, slot_meta, epoch + 1, args.op_code,
                                                     use_full_slot=args.use_full_slot, use_dt_only=args.use_dt_only, no_dial=args.no_dial, use_cls_only=args.use_cls_only, n_gpu=n_gpu)
                    print("### Epoch {:} Test Score : ".format(epoch + 1), eval_res_test)
示例#6
0
def main(args):
    """Train a SomDST model, keep the best dev checkpoint, then test it.

    Steps, in order: seed the RNGs, prepare the train/dev/test examples, build
    the model and load pretrained BERT weights into ``model.encoder.bert``,
    train with separate encoder and decoder optimizers, evaluate on dev every
    ``args.eval_epoch`` epochs and save ``model_best.bin`` under
    ``args.save_dir`` whenever the dev joint accuracy improves, and finally
    reload that checkpoint and evaluate it on the test set under all eight
    ground-truth flag settings.

    Args:
        args: Parsed options (an argparse-style namespace).
            ``args.bert_ckpt_path`` may be overwritten when the checkpoint has
            to be downloaded.

    NOTE(review): ``device`` and ``domain2id`` are not defined inside this
    function; they are presumably module-level names -- confirm.
    """
    def worker_init_fn(worker_id):
        # Give every DataLoader worker its own, reproducible NumPy seed.
        np.random.seed(args.random_seed + worker_id)

    n_gpu = 0
    if torch.cuda.is_available():
        n_gpu = torch.cuda.device_count()
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    rng = random.Random(args.random_seed)
    torch.manual_seed(args.random_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed(args.random_seed)
        torch.cuda.manual_seed_all(args.random_seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # makedirs (rather than mkdir) so a nested save_dir is created in one go.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # Context manager so the ontology file handle is closed deterministically.
    with open(args.ontology_data) as f:
        ontology = json.load(f)
    slot_meta, ontology = make_slot_meta(ontology)
    op2id = OP_SET[args.op_code]
    print(op2id)
    tokenizer = BertTokenizer(args.vocab_path, do_lower_case=True)

    train_data_raw = prepare_dataset(data_path=args.train_data_path,
                                     tokenizer=tokenizer,
                                     slot_meta=slot_meta,
                                     n_history=args.n_history,
                                     max_seq_length=args.max_seq_length,
                                     op_code=args.op_code)

    train_data = MultiWozDataset(train_data_raw,
                                 tokenizer,
                                 slot_meta,
                                 args.max_seq_length,
                                 rng,
                                 ontology,
                                 args.word_dropout,
                                 args.shuffle_state,
                                 args.shuffle_p)
    print("# train examples %d" % len(train_data_raw))

    dev_data_raw = prepare_dataset(data_path=args.dev_data_path,
                                   tokenizer=tokenizer,
                                   slot_meta=slot_meta,
                                   n_history=args.n_history,
                                   max_seq_length=args.max_seq_length,
                                   op_code=args.op_code)
    print("# dev examples %d" % len(dev_data_raw))

    test_data_raw = prepare_dataset(data_path=args.test_data_path,
                                    tokenizer=tokenizer,
                                    slot_meta=slot_meta,
                                    n_history=args.n_history,
                                    max_seq_length=args.max_seq_length,
                                    op_code=args.op_code)
    print("# test examples %d" % len(test_data_raw))

    model_config = BertConfig.from_json_file(args.bert_config_path)
    model_config.dropout = args.dropout
    model_config.attention_probs_dropout_prob = args.attention_probs_dropout_prob
    model_config.hidden_dropout_prob = args.hidden_dropout_prob

    model = SomDST(model_config, len(op2id), len(domain2id), op2id['update'], args.exclude_domain)

    if not os.path.exists(args.bert_ckpt_path):
        args.bert_ckpt_path = download_ckpt(args.bert_ckpt_path, args.bert_config_path, 'assets')

    ckpt = torch.load(args.bert_ckpt_path, map_location='cpu')
    model.encoder.bert.load_state_dict(ckpt)

    # re-initialize added special tokens ([SLOT], [NULL], [EOS])
    model.encoder.bert.embeddings.word_embeddings.weight.data[1].normal_(mean=0.0, std=0.02)
    model.encoder.bert.embeddings.word_embeddings.weight.data[2].normal_(mean=0.0, std=0.02)
    model.encoder.bert.embeddings.word_embeddings.weight.data[3].normal_(mean=0.0, std=0.02)
    model.to(device)

    num_train_steps = int(len(train_data_raw) / args.batch_size * args.n_epochs)

    # Encoder parameters whose name contains one of these get no weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    enc_param_optimizer = list(model.encoder.named_parameters())
    enc_optimizer_grouped_parameters = [
        {'params': [p for n, p in enc_param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in enc_param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

    enc_optimizer = AdamW(enc_optimizer_grouped_parameters, lr=args.enc_lr)
    enc_scheduler = WarmupLinearSchedule(enc_optimizer, int(num_train_steps * args.enc_warmup),
                                         t_total=num_train_steps)

    dec_param_optimizer = list(model.decoder.parameters())
    dec_optimizer = AdamW(dec_param_optimizer, lr=args.dec_lr)
    dec_scheduler = WarmupLinearSchedule(dec_optimizer, int(num_train_steps * args.dec_warmup),
                                         t_total=num_train_steps)

    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.batch_size,
                                  collate_fn=train_data.collate_fn,
                                  num_workers=args.num_workers,
                                  worker_init_fn=worker_init_fn)

    loss_fnc = nn.CrossEntropyLoss()
    best_score = {'epoch': 0, 'joint_acc': 0, 'op_acc': 0, 'final_slot_f1': 0}
    for epoch in range(args.n_epochs):
        batch_loss = []
        model.train()
        for step, batch in enumerate(train_dataloader):
            # Plain ints in the batch are passed through; everything else is
            # moved to the device.
            batch = [b.to(device) if not isinstance(b, int) else b for b in batch]
            input_ids, input_mask, segment_ids, state_position_ids, op_ids,\
            domain_ids, gen_ids, max_value, max_update = batch

            # Use teacher forcing for this step with probability
            # args.decoder_teacher_forcing.
            if rng.random() < args.decoder_teacher_forcing:  # teacher forcing
                teacher = gen_ids
            else:
                teacher = None

            domain_scores, state_scores, gen_scores = model(input_ids=input_ids,
                                                            token_type_ids=segment_ids,
                                                            state_positions=state_position_ids,
                                                            attention_mask=input_mask,
                                                            max_value=max_value,
                                                            op_ids=op_ids,
                                                            max_update=max_update,
                                                            teacher=teacher)

            loss_s = loss_fnc(state_scores.view(-1, len(op2id)), op_ids.view(-1))
            loss_g = masked_cross_entropy_for_value(gen_scores.contiguous(),
                                                    gen_ids.contiguous(),
                                                    tokenizer.vocab['[PAD]'])
            loss = loss_s + loss_g
            if args.exclude_domain is not True:
                loss_d = loss_fnc(domain_scores.view(-1, len(domain2id)), domain_ids.view(-1))
                loss = loss + loss_d
            batch_loss.append(loss.item())

            loss.backward()
            enc_optimizer.step()
            enc_scheduler.step()
            dec_optimizer.step()
            dec_scheduler.step()
            model.zero_grad()

            if step % 100 == 0:
                if args.exclude_domain is not True:
                    print("[%d/%d] [%d/%d] mean_loss : %.3f, state_loss : %.3f, gen_loss : %.3f, dom_loss : %.3f" \
                          % (epoch+1, args.n_epochs, step,
                             len(train_dataloader), np.mean(batch_loss),
                             loss_s.item(), loss_g.item(), loss_d.item()))
                else:
                    print("[%d/%d] [%d/%d] mean_loss : %.3f, state_loss : %.3f, gen_loss : %.3f" \
                          % (epoch+1, args.n_epochs, step,
                             len(train_dataloader), np.mean(batch_loss),
                             loss_s.item(), loss_g.item()))
                # The running mean is reported per 100-step window.
                batch_loss = []

        if (epoch+1) % args.eval_epoch == 0:
            eval_res = model_evaluation(model, dev_data_raw, tokenizer, slot_meta, epoch+1, args.op_code)
            if eval_res['joint_acc'] > best_score['joint_acc']:
                best_score = eval_res
                # Unwrap DataParallel so the saved keys carry no 'module.' prefix.
                model_to_save = model.module if hasattr(model, 'module') else model
                save_path = os.path.join(args.save_dir, 'model_best.bin')
                torch.save(model_to_save.state_dict(), save_path)
            print("Best Score : ", best_score)
            print("\n")

    print("Test using best model...")
    best_epoch = best_score['epoch']
    # NOTE(review): model_best.bin only exists if some dev evaluation improved
    # on a joint accuracy of 0; otherwise the torch.load below fails.
    ckpt_path = os.path.join(args.save_dir, 'model_best.bin')
    model = SomDST(model_config, len(op2id), len(domain2id), op2id['update'], args.exclude_domain)
    ckpt = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(ckpt)
    model.to(device)

    # (is_gt_op, is_gt_p_state, is_gt_gen) settings, listed in the exact order
    # the evaluations have always been run (note: not plain binary counting).
    gt_flag_settings = (
        (False, False, False),
        (False, False, True),
        (False, True, False),
        (False, True, True),
        (True, False, False),
        (True, True, False),
        (True, False, True),
        (True, True, True),
    )
    for is_gt_op, is_gt_p_state, is_gt_gen in gt_flag_settings:
        model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch, args.op_code,
                         is_gt_op=is_gt_op, is_gt_p_state=is_gt_p_state, is_gt_gen=is_gt_gen)
示例#7
0
def main(args):
    """Evaluate saved TransformerDST checkpoints on the test split.

    For every epoch listed in ``args.load_epoch`` (dash-separated, e.g.
    ``"10-20"``), the checkpoint ``model.e<epoch>.bin`` is loaded from
    ``args.save_dir`` and passed to ``model_evaluation`` together with the
    contents of ``test.pt`` from ``args.data_root``. The result of each
    evaluation is printed.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``use_one_optim``, ``recover_e``, ``save_dir``, ``use_cpu``,
            ``random_seed``, ``ontology_data``, ``op_code``, ``vocab_path``,
            ``data_root``, ``bert_config_path``, ``dropout``,
            ``attention_probs_dropout_prob``, ``hidden_dropout_prob``,
            ``exclude_domain``, ``load_epoch``, ``use_full_slot``,
            ``use_dt_only`` and ``no_dial``. The whole namespace is also
            handed to the model as its decoder config. If ``random_seed`` is
            negative it is overwritten with a freshly drawn seed.

    Raises:
        AssertionError: If ``args.use_one_optim`` is not ``True``.
        NotImplementedError: If ``args.recover_e`` is positive.
    """
    # Kept as an assert so the exception type seen by callers is unchanged.
    assert args.use_one_optim is True

    if args.recover_e > 0:
        raise NotImplementedError("This option is from my oldest code version. "
                                  "I have not checked it for this code version.")

    # makedirs (instead of mkdir) also handles a nested save_dir.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
        print("### mkdir {:}".format(args.save_dir))

    if torch.cuda.is_available() and (not args.use_cpu):
        n_gpu = torch.cuda.device_count()
        device = torch.device('cuda')
        print("### Device: {:}".format(device))
    else:
        n_gpu = 0
        print("### Use CPU (Debugging)")
        device = torch.device("cpu")

    if args.random_seed < 0:
        print("### Pick a random seed")
        args.random_seed = random.randrange(1, 100000)

    print("### Random Seed: {:}".format(args.random_seed))
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    if n_gpu > 0:
        torch.cuda.manual_seed(args.random_seed)
        torch.cuda.manual_seed_all(args.random_seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # Close the ontology file deterministically instead of leaking the handle.
    with open(args.ontology_data) as ontology_file:
        ontology = json.load(ontology_file)
    slot_meta, ontology = make_slot_meta(ontology)
    op2id = OP_SET[args.op_code]
    print(op2id)

    tokenizer = BertTokenizer(args.vocab_path, do_lower_case=True)

    # The training subset is only loaded and counted here; it is not
    # evaluated by this function.
    train_path = os.path.join(args.data_root, "train.pt")
    train_data_raw = torch.load(train_path)[:5000]
    print("# train examples %d" % len(train_data_raw))

    test_path = os.path.join(args.data_root, "test.pt")
    test_data_raw = torch.load(test_path)
    print("# test examples %d" % len(test_data_raw))

    model_config = BertConfig.from_json_file(args.bert_config_path)
    model_config.dropout = args.dropout
    model_config.attention_probs_dropout_prob = args.attention_probs_dropout_prob
    model_config.hidden_dropout_prob = args.hidden_dropout_prob

    type_vocab_size = 4
    dec_config = args
    # Vocabulary ids of the four tokens the model constructor is given.
    mask_id, sep_id, pad_id, dash_id = [
        tokenizer.convert_tokens_to_ids([token])[0]
        for token in ('[MASK]', '[SEP]', '[PAD]', '-')
    ]
    model = TransformerDST(model_config, dec_config, len(op2id), len(domain2id),
                           op2id['update'],
                           mask_id, sep_id, pad_id, dash_id,
                           type_vocab_size, args.exclude_domain)

    test_epochs = [int(e) for e in args.load_epoch.strip().lower().split('-')]
    for best_epoch in test_epochs:
        print("### Epoch {:}...".format(best_epoch))
        sys.stdout.flush()
        ckpt_path = os.path.join(args.save_dir, 'model.e{:}.bin'.format(best_epoch))
        ckpt = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(ckpt)
        model.to(device)

        eval_res = model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch, args.op_code,
                                    use_full_slot=args.use_full_slot, use_dt_only=args.use_dt_only,
                                    no_dial=args.no_dial, n_gpu=n_gpu,
                                    is_gt_op=False, is_gt_p_state=False, is_gt_gen=False)

        print("### Epoch {:} Test Score : ".format(best_epoch), eval_res)
        print('\n'*2)
        sys.stdout.flush()
# Example #8
# 0
def main(args):
    """Load (or build) the train/dev/test splits and train the ``DST`` model.

    Each split is read from a ``<data_path>.pk`` cache via ``load_data`` when
    that file exists; otherwise it is built with ``prepare_dataset``. Training
    is delegated to ``dst.model.train_model``.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``random_seed``, ``out_dir``, ``ontology_data``, ``op_code``,
            ``train_data_path``, ``dev_data_path``, ``test_data_path``,
            ``n_history``, ``max_seq_length`` and ``filename``. The namespace
            is also passed to the ``DST`` constructor.
    """
    n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0

    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed(args.random_seed)
        torch.cuda.manual_seed_all(args.random_seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # makedirs (instead of mkdir) also handles a nested out_dir.
    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)

    # Close the ontology file deterministically instead of leaking the handle.
    with open(args.ontology_data) as ontology_file:
        ontology = json.load(ontology_file)
    slot_meta, ontology = make_slot_meta(ontology)
    # The mapping is not used below; the lookup is kept so that an unknown
    # op_code still raises KeyError at this point, as before.
    _ = OP_SET[args.op_code]

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def load_or_prepare(data_path):
        """Return the split at *data_path*, preferring its ``.pk`` cache."""
        cache_path = data_path + ".pk"
        if os.path.exists(cache_path):
            return load_data(cache_path)
        return prepare_dataset(data_path=data_path,
                               tokenizer=tokenizer,
                               slot_meta=slot_meta,
                               n_history=args.n_history,
                               max_seq_length=args.max_seq_length,
                               op_code=args.op_code)

    train_data_raw = load_or_prepare(args.train_data_path)
    print("# train examples %d" % len(train_data_raw))

    dev_data_raw = load_or_prepare(args.dev_data_path)
    print("# dev examples %d" % len(dev_data_raw))

    test_data_raw = load_or_prepare(args.test_data_path)
    print("# test examples %d" % len(test_data_raw))

    dst = DST(args, ontology, slot_meta)

    # NOTE(review): the test split (not the dev split loaded above) is what
    # gets passed alongside the training data - confirm this is intended.
    dst.model.train_model(train_data_raw, test_data_raw, ontology, slot_meta,
                          args.out_dir, args.filename, show_running_loss=True)
def main(args):
    """Run TransformerDST inference on the test split for saved checkpoints.

    The test set is read from ``test.pt`` under ``args.data_root_test``; if
    that file is missing it is built with ``prepare_dataset_for_inference``
    and saved there. The model is constructed, its BERT encoder is
    initialised from ``args.bert_ckpt_path`` with the token-type table
    widened to four segments, and then, for every epoch in the
    dash-separated ``args.load_epoch``, ``model.e<epoch>.bin`` is loaded and
    ``model_evaluation`` is called with ``submission=True``.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``save_dir``, ``use_cpu``, ``random_seed``, ``ontology_data``,
            ``op_code``, ``bert_config``, ``data_root_test``,
            ``test_data_path``, ``n_history``, ``max_seq_length``,
            ``bert_config_path``, ``bert_ckpt_path``, ``exclude_domain``,
            ``load_epoch``, ``use_full_slot``, ``use_dt_only`` and
            ``no_dial``. The whole namespace is also handed to the model as
            its decoder config. If ``random_seed`` is negative it is
            overwritten with a freshly drawn seed.
    """
    # makedirs (instead of mkdir) also handles a nested save_dir.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)
        print("### mkdir {:}".format(args.save_dir))

    # Bug fix: the original tested ``args.use_cpu`` without negation, so CUDA
    # was selected only when the CPU had been requested.
    if torch.cuda.is_available() and (not args.use_cpu):
        n_gpu = torch.cuda.device_count()
        device = torch.device('cuda')
        print("### Device: {:}".format(device))
    else:
        n_gpu = 0
        print("### Use CPU (Debugging)")
        device = torch.device("cpu")

    if args.random_seed < 0:
        print("### Pick a random seed")
        args.random_seed = random.randrange(1, 100000)

    print("### Random Seed: {:}".format(args.random_seed))
    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)

    if n_gpu > 0:
        torch.cuda.manual_seed(args.random_seed)
        torch.cuda.manual_seed_all(args.random_seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # Close the ontology file deterministically instead of leaking the handle.
    with open(args.ontology_data) as ontology_file:
        ontology = json.load(ontology_file)
    slot_meta, ontology = make_slot_meta(ontology)
    op2id = OP_SET[args.op_code]
    print(op2id)

    tokenizer = BertTokenizer.from_pretrained(args.bert_config)

    special_tokens = ['[SLOT]', '[NULL]']
    special_tokens_dict = {'additional_special_tokens': special_tokens}
    tokenizer.add_special_tokens(special_tokens_dict)

    # Build the processed test set once and cache it next to the data.
    test_path = os.path.join(args.data_root_test, "test.pt")
    if os.path.exists(test_path):
        test_data_raw = torch.load(test_path)
    else:
        test_data_raw = prepare_dataset_for_inference(data_path=args.test_data_path,
                                                      data_list=None,
                                                      tokenizer=tokenizer,
                                                      slot_meta=slot_meta,
                                                      n_history=args.n_history,
                                                      max_seq_length=args.max_seq_length,
                                                      op_code=args.op_code)
        torch.save(test_data_raw, test_path)

    print("# test examples %d" % len(test_data_raw))

    # All dropout probabilities are zeroed for inference.
    model_config = BertConfig.from_json_file(args.bert_config_path)
    model_config.dropout = 0.
    model_config.attention_probs_dropout_prob = 0.
    model_config.hidden_dropout_prob = 0.

    type_vocab_size = 4
    dec_config = args
    # Vocabulary ids of the four tokens the model constructor is given.
    mask_id, sep_id, pad_id, dash_id = [
        tokenizer.convert_tokens_to_ids([token])[0]
        for token in ('[MASK]', '[SEP]', '[PAD]', '-')
    ]
    model = TransformerDST(model_config, dec_config, len(op2id), len(domain2id),
                           op2id['update'],
                           mask_id, sep_id, pad_id, dash_id,
                           type_vocab_size, args.exclude_domain)

    # Widen the pretrained token-type embedding table to ``type_vocab_size``
    # rows by tiling it, then overwrite rows 2 and 3 with a copy of row 0.
    state_dict = torch.load(args.bert_ckpt_path, map_location='cpu')
    _k = 'bert.embeddings.token_type_embeddings.weight'
    print("config.type_vocab_size != state_dict[bert.embeddings.token_type_embeddings.weight] ({0} != {1})".format(
        type_vocab_size, state_dict[_k].shape[0]))
    state_dict[_k] = state_dict[_k].repeat(int(type_vocab_size / state_dict[_k].shape[0]), 1)
    state_dict[_k].data[2, :].copy_(state_dict[_k].data[0, :])
    state_dict[_k].data[3, :].copy_(state_dict[_k].data[0, :])
    model.bert.load_state_dict(state_dict, strict=False)

    # re-initialize added special tokens ([SLOT], [NULL], [EOS])
    for row in (1, 2, 3):
        model.bert.embeddings.word_embeddings.weight.data[row].normal_(mean=0.0, std=0.02)

    # re-initialize seg-2, seg-3
    for row in (2, 3):
        model.bert.embeddings.token_type_embeddings.weight.data[row].normal_(mean=0.0, std=0.02)
    model.bert.resize_token_embeddings(len(tokenizer))

    test_epochs = [int(e) for e in args.load_epoch.strip().lower().split('-')]
    for best_epoch in test_epochs:
        print("### Epoch {:}...".format(best_epoch))
        sys.stdout.flush()
        # NOTE(review): the checkpoint root is hard-coded; an absolute
        # ``args.save_dir`` makes os.path.join ignore it.
        ckpt_path = os.path.join('/opt/ml/code/transformer_dst', args.save_dir,
                                 'model.e{:}.bin'.format(best_epoch))
        ckpt = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(ckpt)
        model.to(device)

        eval_res = model_evaluation(model, test_data_raw, tokenizer, slot_meta, best_epoch, args.op_code,
                                    use_full_slot=args.use_full_slot, use_dt_only=args.use_dt_only,
                                    no_dial=args.no_dial, n_gpu=n_gpu,
                                    is_gt_op=False, is_gt_p_state=False, is_gt_gen=False, submission=True)

        print("### Epoch {:} Test Score : ".format(best_epoch), eval_res)
        print('\n' * 2)
        sys.stdout.flush()
# Example #10
# 0
def main(args):
    """Train ``DST_PICK``, keep the best dev checkpoint and write its results.

    Each split is read from a ``<data_path>.pk`` cache via ``load_data`` when
    that file exists; otherwise it is built with ``prepare_dataset``. The
    model is trained one training example per optimizer step. Every
    ``args.eval_epoch`` epochs it is evaluated on the dev split, and the
    weights are saved to ``<out_dir>/<filename>.bin`` whenever
    ``joint_acc_score`` improves. Afterwards the saved weights are reloaded,
    evaluated on the dev split again, and the results are written to
    ``<out_dir>/<filename>.res`` and ``<out_dir>/<filename>.pred``.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``random_seed``, ``out_dir``, ``ontology_data``, ``op_code``,
            ``train_data_path``, ``dev_data_path``, ``test_data_path``,
            ``n_history``, ``max_seq_length``, ``bert_ckpt_path``,
            ``bert_config_path``, ``n_epochs``, ``eval_epoch`` and
            ``filename``. ``bert_ckpt_path`` is overwritten with the value
            returned by ``download_ckpt`` when the given path does not exist.
    """
    # NOTE(review): ``device`` is not assigned in this function; it is
    # presumably a module-level global - confirm.
    n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0

    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed(args.random_seed)
        torch.cuda.manual_seed_all(args.random_seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # makedirs (instead of mkdir) also handles a nested out_dir.
    if not os.path.exists(args.out_dir):
        os.makedirs(args.out_dir)

    # Close the ontology file deterministically instead of leaking the handle.
    with open(args.ontology_data) as ontology_file:
        ontology = json.load(ontology_file)
    slot_meta, ontology = make_slot_meta(ontology)
    # The mapping is not used below; the lookup is kept so that an unknown
    # op_code still raises KeyError at this point, as before.
    _ = OP_SET[args.op_code]

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def load_or_prepare(data_path):
        """Return the split at *data_path*, preferring its ``.pk`` cache."""
        cache_path = data_path + ".pk"
        if os.path.exists(cache_path):
            return load_data(cache_path)
        return prepare_dataset(data_path=data_path,
                               tokenizer=tokenizer,
                               slot_meta=slot_meta,
                               n_history=args.n_history,
                               max_seq_length=args.max_seq_length,
                               op_code=args.op_code)

    train_data_raw = load_or_prepare(args.train_data_path)
    print("# train examples %d" % len(train_data_raw))

    dev_data_raw = load_or_prepare(args.dev_data_path)
    print("# dev examples %d" % len(dev_data_raw))

    # The test split is only loaded and counted; evaluation uses the dev split.
    test_data_raw = load_or_prepare(args.test_data_path)
    print("# test examples %d" % len(test_data_raw))

    model = DST_PICK.from_pretrained('bert-base-uncased')

    if not os.path.exists(args.bert_ckpt_path):
        args.bert_ckpt_path = download_ckpt(args.bert_ckpt_path,
                                            args.bert_config_path, 'assets')

    ckpt = torch.load(args.bert_ckpt_path, map_location='cpu')
    model.bert.load_state_dict(ckpt)

    # NOTE(review): both parameter groups use weight_decay 0.0, so the split
    # currently has no effect - confirm whether decay was meant to be enabled.
    no_decay = ["bias", "LayerNorm.weight", "LayerNorm.bias"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
        {
            "params": [
                p for n, p in model.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters)

    model.to(device)

    checkpoint_path = os.path.join(args.out_dir, args.filename + '.bin')
    best_joint_acc = float("-inf")
    best_epoch = 0
    for epoch in range(args.n_epochs):
        batch_loss = []
        model.train()
        for step, dialog in enumerate(tqdm(train_data_raw, desc="training")):
            inp = model.generate_train_instances(dialog, ontology, tokenizer,
                                                 device)

            # ignore dialog with no training slots
            if not inp:
                continue

            loss = model(**inp)
            batch_loss.append(loss.item())

            loss.backward()
            optimizer.step()
            model.zero_grad()

            if step % 100 == 0:
                print("[%d/%d] [%d/%d] mean_loss : %.3f" %
                      (epoch + 1, args.n_epochs, step, len(train_data_raw),
                       np.mean(batch_loss)))
                batch_loss = []

        if (epoch + 1) % args.eval_epoch == 0:
            eval_res, res_per_domain, pred = model.evaluate(
                dev_data_raw, tokenizer, ontology, slot_meta, epoch + 1,
                device)

            if eval_res['joint_acc_score'] > best_joint_acc:
                best_joint_acc = eval_res['joint_acc_score']
                # Unwrap a DataParallel-style wrapper before saving.
                model_to_save = model.module if hasattr(model, 'module') else model
                torch.save(model_to_save.state_dict(), checkpoint_path)
                best_epoch = epoch + 1
            print("Best Score : ", best_joint_acc)
            print("\n")

    print("Test using best model...")
    model = DST_PICK.from_pretrained('bert-base-uncased')
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(ckpt)
    model.to(device)

    eval_res, res_per_domain, pred = model.evaluate(dev_data_raw, tokenizer,
                                                    ontology, slot_meta,
                                                    best_epoch, device)
    # save to file
    result_prefix = args.out_dir + "/" + args.filename
    save_result_to_file(result_prefix + ".res", eval_res, res_per_domain)
    # Write predictions through a context manager so the file is flushed and
    # closed before the function returns.
    with open(result_prefix + ".pred", 'w') as pred_file:
        json.dump(pred, pred_file)
# Example #11
# 0
def main(args):
    """Continue training a SomDST model from a saved checkpoint.

    The processed train/dev data is cached as pickles under a fixed
    directory; when any of the three pickle files is missing, the data is
    rebuilt from ``args.train_data_path`` and the cache is rewritten. The
    model weights are loaded from a fixed checkpoint path, then trained with
    separate AdamW optimizers and linear warm-up schedules for the encoder
    and the decoder. Every ``args.eval_epoch`` epochs the model is evaluated
    on the dev split, metrics are logged to Weights & Biases, and the weights
    are saved as ``model_best.bin`` when ``joint_acc`` improves. Weights are
    additionally saved every 10 epochs as ``model_<epoch>.bin``.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``random_seed``, ``save_dir``, ``ontology_data``, ``op_code``,
            ``train_data_path``, ``n_history``, ``max_seq_length``,
            ``word_dropout``, ``shuffle_state``, ``shuffle_p``,
            ``bert_config_path``, ``dropout``,
            ``attention_probs_dropout_prob``, ``hidden_dropout_prob``,
            ``exclude_domain``, ``batch_size``, ``n_epochs``, ``enc_lr``,
            ``enc_warmup``, ``dec_lr``, ``dec_warmup``, ``num_workers``,
            ``decoder_teacher_forcing`` and ``eval_epoch``.
    """
    # NOTE(review): ``device`` is not assigned in this function; it is
    # presumably a module-level global - confirm.

    def worker_init_fn(worker_id):
        """Seed NumPy in each DataLoader worker with a distinct value."""
        np.random.seed(args.random_seed + worker_id)

    n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0

    np.random.seed(args.random_seed)
    random.seed(args.random_seed)
    rng = random.Random(args.random_seed)
    torch.manual_seed(args.random_seed)
    if n_gpu > 0:
        torch.cuda.manual_seed(args.random_seed)
        torch.cuda.manual_seed_all(args.random_seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True

    # makedirs (instead of mkdir) also handles a nested save_dir.
    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    # Close the ontology file deterministically instead of leaking the handle.
    with open(args.ontology_data) as ontology_file:
        ontology = json.load(ontology_file)
    slot_meta, ontology = make_slot_meta(ontology)
    op2id = OP_SET[args.op_code]
    tokenizer = BertTokenizer.from_pretrained("dsksd/bert-ko-small-minimal")

    out_path = '/opt/ml/code/new-som-dst/pickles'
    train_raw_pkl = os.path.join(out_path, 'train_data_raw.pkl')
    train_pkl = os.path.join(out_path, 'train_data.pkl')
    dev_raw_pkl = os.path.join(out_path, 'dev_data_raw.pkl')

    def read_pickle(path):
        """Load one cached object from *path*."""
        # NOTE: pickle executes arbitrary code on load; only files written by
        # this script into the local cache directory should be read here.
        with open(path, 'rb') as f:
            return pickle.load(f)

    def write_pickle(obj, path):
        """Store *obj* at *path*."""
        with open(path, 'wb') as f:
            pickle.dump(obj, f)

    # Bug fix: test for the files themselves, not only the directory, so an
    # empty or partially written cache is rebuilt instead of crashing.
    if all(os.path.exists(p) for p in (train_raw_pkl, train_pkl, dev_raw_pkl)):
        print("Pickles are exist!")
        train_data_raw = read_pickle(train_raw_pkl)
        train_data = read_pickle(train_pkl)
        dev_data_raw = read_pickle(dev_raw_pkl)
        print("Pickles brought!")
    else:
        print("Pickles are not exist!")
        train_dials, dev_dials = load_dataset(args.train_data_path)
        print(f"t_d_len : {len(train_dials)}, d_d_len : {len(dev_dials)}")
        train_data_raw = prepare_dataset(dials=train_dials,
                                         tokenizer=tokenizer,
                                         slot_meta=slot_meta,
                                         n_history=args.n_history,
                                         max_seq_length=args.max_seq_length,
                                         op_code=args.op_code)
        train_data = WosDataset(train_data_raw, tokenizer, slot_meta,
                                args.max_seq_length, rng, ontology,
                                args.word_dropout, args.shuffle_state,
                                args.shuffle_p)

        dev_data_raw = prepare_dataset(dials=dev_dials,
                                       tokenizer=tokenizer,
                                       slot_meta=slot_meta,
                                       n_history=args.n_history,
                                       max_seq_length=args.max_seq_length,
                                       op_code=args.op_code)
        os.makedirs(out_path, exist_ok=True)
        write_pickle(train_data_raw, train_raw_pkl)
        write_pickle(train_data, train_pkl)
        write_pickle(dev_data_raw, dev_raw_pkl)
        print("Pickles saved!")

    print("# train examples %d" % len(train_data_raw))
    print("# dev examples %d" % len(dev_data_raw))

    model_config = BertConfig.from_json_file(args.bert_config_path)
    model_config.dropout = args.dropout
    model_config.attention_probs_dropout_prob = args.attention_probs_dropout_prob
    model_config.hidden_dropout_prob = args.hidden_dropout_prob

    model = SomDST(model_config, len(op2id), len(domain2id), op2id['update'],
                   args.exclude_domain)
    # NOTE(review): the starting checkpoint path is hard-coded; the weights
    # come entirely from this file, not from a pretrained BERT checkpoint.
    ckpt = torch.load('/opt/ml/outputs/model_20.bin', map_location='cpu')
    model.load_state_dict(ckpt)
    print("model is loaded!")

    model.to(device)

    print()

    wandb.watch(model)

    num_train_steps = int(
        len(train_data_raw) / args.batch_size * args.n_epochs)

    # Encoder: weight decay 0.01 everywhere except biases and LayerNorm.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    enc_param_optimizer = list(model.encoder.named_parameters())
    enc_optimizer_grouped_parameters = [{
        'params': [
            p for n, p in enc_param_optimizer
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay': 0.01
    }, {
        'params': [
            p for n, p in enc_param_optimizer
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay': 0.0
    }]

    enc_optimizer = AdamW(enc_optimizer_grouped_parameters, lr=args.enc_lr)
    enc_scheduler = get_linear_schedule_with_warmup(
        enc_optimizer,
        num_warmup_steps=int(num_train_steps * args.enc_warmup),
        num_training_steps=num_train_steps)

    dec_param_optimizer = list(model.decoder.parameters())
    dec_optimizer = AdamW(dec_param_optimizer, lr=args.dec_lr)
    dec_scheduler = get_linear_schedule_with_warmup(
        dec_optimizer,
        num_warmup_steps=int(num_train_steps * args.dec_warmup),
        num_training_steps=num_train_steps)

    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.batch_size,
                                  collate_fn=train_data.collate_fn,
                                  num_workers=args.num_workers,
                                  worker_init_fn=worker_init_fn)

    loss_fnc = nn.CrossEntropyLoss()
    use_domain_loss = args.exclude_domain is not True
    best_score = {'epoch': 0, 'joint_acc': 0, 'op_acc': 0, 'final_slot_f1': 0}
    for epoch in range(args.n_epochs):
        batch_loss = []
        model.train()
        for step, batch in enumerate(train_dataloader):
            batch = [
                b.to(device) if not isinstance(b, int) else b for b in batch
            ]
            input_ids, input_mask, segment_ids, state_position_ids, op_ids,\
            domain_ids, gen_ids, max_value, max_update = batch

            # Teacher forcing is applied per batch with the configured
            # probability.
            teacher = gen_ids if rng.random() < args.decoder_teacher_forcing else None

            domain_scores, state_scores, gen_scores = model(
                input_ids=input_ids,
                token_type_ids=segment_ids,
                state_positions=state_position_ids,
                attention_mask=input_mask,
                max_value=max_value,
                op_ids=op_ids,
                max_update=max_update,
                teacher=teacher)

            loss_s = loss_fnc(state_scores.view(-1, len(op2id)),
                              op_ids.view(-1))
            loss_g = masked_cross_entropy_for_value(
                gen_scores.contiguous(),  # B, J', K, V
                gen_ids.contiguous(),  # B, J', K
                tokenizer.vocab['[PAD]'])

            loss = loss_s + loss_g
            if use_domain_loss:
                loss_d = loss_fnc(domain_scores.view(-1, len(domain2id)),
                                  domain_ids.view(-1))
                loss = loss + loss_d
            batch_loss.append(loss.item())

            loss.backward()

            enc_optimizer.step()
            enc_scheduler.step()
            dec_optimizer.step()
            dec_scheduler.step()
            model.zero_grad()

            if (step + 1) % 100 == 0:
                if use_domain_loss:
                    print("[%d/%d] [%d/%d] mean_loss : %.3f, state_loss : %.3f, gen_loss : %.3f, dom_loss : %.3f" \
                          % (epoch+1, args.n_epochs, step+1,
                             len(train_dataloader), np.mean(batch_loss),
                             loss_s.item(), loss_g.item(), loss_d.item()))
                else:
                    print("[%d/%d] [%d/%d] mean_loss : %.3f, state_loss : %.3f, gen_loss : %.3f" \
                          % (epoch+1, args.n_epochs, step+1,
                             len(train_dataloader), np.mean(batch_loss),
                             loss_s.item(), loss_g.item()))
                batch_loss = []

        if (epoch + 1) % args.eval_epoch == 0:
            eval_res = model_evaluation(model, dev_data_raw, tokenizer,
                                        slot_meta, epoch + 1, args.op_code)
            if eval_res['joint_acc'] > best_score['joint_acc']:
                best_score = eval_res
                # Unwrap DataParallel before saving.
                model_to_save = model.module if hasattr(model, 'module') else model
                save_path = os.path.join(args.save_dir, 'model_best.bin')
                torch.save(model_to_save.state_dict(), save_path)
            print("Best Score : ", best_score)
            print("\n")

            # Bug fix: log only in epochs where an evaluation actually ran.
            # Previously this ran every epoch and either raised NameError
            # (no evaluation yet) or re-logged stale metrics.
            wandb.log({
                'joint_acc': eval_res['joint_acc'],
                'slot_acc': eval_res['slot_acc'],
                'slot_f1': eval_res['slot_f1'],
                'op_acc': eval_res['op_acc'],
                'op_f1': eval_res['op_f1'],
                'final_slot_f1': eval_res['final_slot_f1']
            })

        # save model at 10 epochs
        if (epoch + 1) % 10 == 0:
            model_to_save = model.module if hasattr(model, 'module') else model
            save_path = os.path.join(args.save_dir, f'model_{epoch+1}.bin')
            torch.save(model_to_save.state_dict(), save_path)
            # Bug fix: report the file name that was actually written.
            print(f"model_{epoch+1}.bin is saved!")