Example #1
def convert_t5(args):
    logging.info('converting T5 model from Huggingface...')
    if not os.path.exists(args.dest_dir):
        os.mkdir(args.dest_dir)
    converted = {}
    # convert and save vocab
    convert_vocab(args, converted)
    # convert and save config
    gluon_cfg = convert_config(args, converted)
    # convert, (test), and save model
    hf_t5 = HF_T5.from_pretrained(args.model_name)
    gluon_t5 = Gluon_T5.from_cfg(gluon_cfg)
    gluon_t5 = convert_params(args, converted, hf_t5, gluon_t5)
    gluon_t5.hybridize()
    # test model if needed
    if args.test:
        test_conversion(args, hf_t5, gluon_t5)
    # rename with sha1sum
    rename(args, converted)
    logging.info('conversion completed.')
    logging.info('file statistics:')
    for item, new_path in converted.items():
        logging.info('filename: {}\tsize: {}\tsha1sum: {}'.format(
            os.path.basename(new_path), os.path.getsize(new_path),
            sha1sum(new_path)))
    return converted
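A minimal driver for convert_t5 might look like the following sketch; the attribute names (model_name, dest_dir, test) are taken from what the function reads above, but the exact CLI of the original script is an assumption.

import argparse

parser = argparse.ArgumentParser(description='Convert a Hugging Face T5 checkpoint to GluonNLP.')
parser.add_argument('--model_name', default='t5-small')   # Hugging Face model id
parser.add_argument('--dest_dir', default='./converted')  # where converted files are written
parser.add_argument('--test', action='store_true')        # numerically compare the two models
convert_t5(parser.parse_args())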
Example #2
    def __init__(self, config, x_embed):
        super().__init__()

        self.model = T5Model.from_pretrained(config.pretrained_weights)
        self.encoder_out_size = self.model.config.d_model  # 1024 for t5-large

        return
Example #3
    def create_and_check_t5_model(
        self,
        config,
        encoder_input_ids,
        decoder_input_ids,
        encoder_attention_mask,
        decoder_attention_mask,
        decoder_lm_labels,
    ):
        model = T5Model(config=config)
        model.eval()
        decoder_output, encoder_output = model(
            encoder_input_ids=encoder_input_ids,
            decoder_input_ids=decoder_input_ids,
            encoder_attention_mask=encoder_attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        decoder_output, encoder_output = model(
            encoder_input_ids=encoder_input_ids,
            decoder_input_ids=decoder_input_ids)

        result = {
            "encoder_output": encoder_output,
            "decoder_output": decoder_output,
        }
        self.parent.assertListEqual(
            list(result["encoder_output"].size()),
            [self.batch_size, self.encoder_seq_length, self.hidden_size])
        self.parent.assertListEqual(
            list(result["decoder_output"].size()),
            [self.batch_size, self.decoder_seq_length, self.hidden_size])
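Example #3 uses the very early encoder_input_ids keyword API; on current transformers the equivalent call is written with input_ids and attention_mask instead (a sketch, not the original test):

    decoder_output = model(
        input_ids=encoder_input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=encoder_attention_mask,
        decoder_attention_mask=decoder_attention_mask,
    ).last_hidden_state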
Example #4
    def test_export_to_onnx(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        model = T5Model(config_and_inputs[0]).to(torch_device)
        with tempfile.TemporaryDirectory() as tmpdirname:
            torch.onnx.export(
                model, config_and_inputs[1], f"{tmpdirname}/t5_test.onnx", export_params=True, opset_version=9,
            )
Example #5
    def __init__(self, dropout=0.5):
        super().__init__()
        drop = nn.Dropout(dropout)

        if use_t5:
            """
            Use t5_model.encoder as the encoder for this model. Note that unlike the custom transformer, you don't
            need to use an external input or positional embedding for the T5 transformer 
            (i.e. don't define self.in_embed or self.pos_embed) since it already defines them internally

            You may specify layer weights to freeze during finetuning by modifying the freeze_layers global variable
            """
            ### Your code here ###
            self.t5_model = T5Model.from_pretrained(f't5-{use_t5}')
            self.t5_encoder = self.t5_model.encoder

            for i_layer, block in enumerate(self.t5_encoder.block):
                if i_layer in freeze_layers:
                    for param in block.parameters():
                        param.requires_grad = False
        else:
            # Input embedding for custom transformer
            self.in_embed = nn.Sequential(nn.Embedding(in_vocab.n, n_hid, padding_idx=in_vocab.pad), drop)
            # Positional embedding for custom transformer
            self.pos_embed = nn.Embedding(1 + n_max_in, n_hid)  # Use the first position as global vector
            self.transformer_layers = nn.ModuleList(TransformerBlock() for _ in range(n_layers))

        self.gcn = GCN(n_head=args.n_head, dropout=args.dropout)

        self.decoder = TreeDecoder()

        if not use_t5:
            self.apply(self.init_weight)
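The snippet above relies on module-level names (use_t5, freeze_layers, in_vocab, n_hid, args, ...) defined elsewhere in the original assignment. A minimal, hypothetical setup for the T5 branch could be:

use_t5 = 'base'            # picks t5-base via the f't5-{use_t5}' string above
freeze_layers = {0, 1, 2}  # hypothetical choice: freeze the first three encoder blocks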
Example #6
def get_emb(inputs_list, model_name, max_length=512):
    if 't5' in model_name:
        tokenizer = T5Tokenizer.from_pretrained(TOKEN_DIR)
        model = T5Model.from_pretrained(MODEL_DIR)
        inputs = tokenizer.batch_encode_plus(inputs_list,
                                             max_length=max_length,
                                             pad_to_max_length=True,
                                             return_tensors="pt")
        outputs = model(input_ids=inputs['input_ids'],
                        decoder_input_ids=inputs['input_ids'])
        last_hidden_states = torch.mean(outputs[0], dim=1)
        return last_hidden_states.tolist()

    elif 'bert' in model_name:
        tokenizer = BertTokenizer.from_pretrained(
            'bert-base-multilingual-cased')
        model = TFBertModel.from_pretrained('bert-base-multilingual-cased')
        batch_encoding = tokenizer.batch_encode_plus(
            inputs_list,
            max_length=max_length,
            pad_to_max_length=True)

        outputs = model(tf.convert_to_tensor(batch_encoding['input_ids'])
                        )  # shape: (batch,sequence length, hidden state)
        embeddings = tf.reduce_mean(outputs[0], 1)
        return embeddings.numpy().tolist()
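TOKEN_DIR and MODEL_DIR are module-level constants in the original script; assuming they live in the same module and point at a T5 checkpoint, a hypothetical call looks like:

TOKEN_DIR = MODEL_DIR = 't5-small'  # hypothetical: point both at the hub checkpoint
emb = get_emb(["a first sentence", "a second one"], model_name='t5-small')
print(len(emb), len(emb[0]))  # 2 vectors of d_model size (512 for t5-small)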
Example #7
    def create_and_check_t5_model(
        self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
    ):
        model = T5Model(config=config)
        model.to(torch_device)
        model.eval()
        decoder_output, decoder_past, encoder_output = model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        decoder_output, decoder_past, encoder_output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

        result = {
            "encoder_output": encoder_output,
            "decoder_output": decoder_output,
            "decoder_past": decoder_past,
        }
        self.parent.assertListEqual(
            list(result["encoder_output"].size()), [self.batch_size, self.encoder_seq_length, self.hidden_size]
        )
        self.parent.assertListEqual(
            list(result["decoder_output"].size()), [self.batch_size, self.decoder_seq_length, self.hidden_size]
        )
        self.parent.assertEqual(len(decoder_past), 2)
        # decoder_past[0] should correspond to encoder output
        self.parent.assertTrue(torch.all(decoder_past[0][0] == encoder_output))
        # There should be `num_layers` key value embeddings stored in decoder_past[1]
        self.parent.assertEqual(len(decoder_past[1]), config.num_layers)
        # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
        self.parent.assertEqual(len(decoder_past[1][0]), 4)
Example #8
    def create_and_check_decoder_model_past_large_inputs(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        model = T5Model(config=config).get_decoder().to(torch_device).eval()
        # first forward pass
        outputs = model(input_ids, use_cache=True)

        output, past_key_values = outputs.to_tuple()

        # create several hypothetical next tokens and extend next_input_ids
        next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)

        # append the new tokens to input_ids
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)

        output_from_no_past = model(next_input_ids)["last_hidden_state"]
        output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -3:, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, :, random_slice_idx].detach()

        self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
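The cache-consistency idea can be reproduced outside the test harness. A sketch against a released checkpoint, assuming a transformers version with the past_key_values API:

import torch
from transformers import T5Model

decoder = T5Model.from_pretrained('t5-small').get_decoder().eval()
ids = torch.randint(0, decoder.config.vocab_size, (1, 8))
with torch.no_grad():
    past = decoder(ids[:, :-1], use_cache=True).past_key_values  # cache for the first 7 tokens
    cached = decoder(ids[:, -1:], past_key_values=past).last_hidden_state
    full = decoder(ids).last_hidden_state[:, -1:]  # same position, computed without cache
print(torch.allclose(cached, full, atol=1e-4))  # expected: True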
Example #9
    def get_model(self) -> Union[T5Model, T5EncoderModel]:

        if not self._decoder:
            if self._half_precision_model:
                model = T5EncoderModel.from_pretrained(
                    self._model_directory, torch_dtype=torch.float16)
            else:
                model = T5EncoderModel.from_pretrained(self._model_directory)
        else:
            if self._half_precision_model:
                model = T5Model.from_pretrained(self._model_directory,
                                                torch_dtype=torch.float16)
            else:
                model = T5Model.from_pretrained(self._model_directory)

        return model
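torch_dtype=torch.float16 loads the weights in half precision up front, whereas Example #18 below loads in fp32 and casts with .half() afterwards. A standalone equivalent of the encoder branch, assuming a t5-small download:

import torch
from transformers import T5EncoderModel

model = T5EncoderModel.from_pretrained('t5-small', torch_dtype=torch.float16)
print(next(model.parameters()).dtype)  # torch.float16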
Example #10
    def check_prepare_lm_labels_via_shift_left(
        self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
    ):
        model = T5Model(config=config)
        model.to(torch_device)
        model.eval()

        # make sure that lm_labels are correctly padded from the right
        lm_labels.masked_fill_((lm_labels == self.decoder_start_token_id), self.eos_token_id)

        # add causal pad token mask
        triangular_mask = torch.tril(lm_labels.new_ones(lm_labels.shape)).logical_not()
        lm_labels.masked_fill_(triangular_mask, self.pad_token_id)
        decoder_input_ids = model._shift_right(lm_labels)

        for i, (decoder_input_ids_slice, lm_labels_slice) in enumerate(zip(decoder_input_ids, lm_labels)):
            # first item
            self.parent.assertEqual(decoder_input_ids_slice[0].item(), self.decoder_start_token_id)
            if i < decoder_input_ids_slice.shape[-1]:
                if i < decoder_input_ids.shape[-1] - 1:
                    # items before diagonal
                    self.parent.assertListEqual(
                        decoder_input_ids_slice[1 : i + 1].tolist(), lm_labels_slice[:i].tolist()
                    )
                # pad items after diagonal
                if i < decoder_input_ids.shape[-1] - 2:
                    self.parent.assertListEqual(
                        decoder_input_ids_slice[i + 2 :].tolist(), lm_labels_slice[i + 1 : -1].tolist()
                    )
            else:
                # all items after square
                self.parent.assertListEqual(decoder_input_ids_slice[1:].tolist(), lm_labels_slice[:-1].tolist())
Example #11
    def create_and_check_model(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        model = T5Model(config=config)
        model.to(torch_device)
        model.eval()
        result = model(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
        )
        result = model(input_ids=input_ids,
                       decoder_input_ids=decoder_input_ids)
        decoder_output = result.last_hidden_state
        decoder_past = result.past_key_values
        encoder_output = result.encoder_last_hidden_state

        self.parent.assertEqual(
            encoder_output.size(),
            (self.batch_size, self.encoder_seq_length, self.hidden_size))
        self.parent.assertEqual(
            decoder_output.size(),
            (self.batch_size, self.decoder_seq_length, self.hidden_size))
        # There should be `num_layers` key value embeddings stored in decoder_past
        self.parent.assertEqual(len(decoder_past), config.num_layers)
        # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple
        self.parent.assertEqual(len(decoder_past[0]), 4)
Example #12
    def create_and_check_t5_decoder_model_past(
        self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
    ):
        model = T5Model(config=config).get_decoder()
        model.to(torch_device)
        model.eval()

        # first forward pass
        outputs = model(input_ids, use_cache=True)
        outputs_use_cache_conf = model(input_ids)
        outputs_no_past = model(input_ids, use_cache=False)

        self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
        self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)

        output, past_key_value_states = outputs

        # create a hypothetical next token and extend next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # append the new token to input_ids
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)

        output_from_no_past = model(next_input_ids)[0]
        output_from_past = model(next_tokens, past_key_value_states=past_key_value_states)[0]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
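Later transformers releases renamed the keyword, so on current versions the cached call would read (a sketch):

output_from_past = model(next_tokens, past_key_values=past_key_value_states)[0]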
Example #13
    def rank(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Any:
        t5_outputs = T5Model.forward(
            self,
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=input_ids,
            decoder_attention_mask=attention_mask,
        )
        next_token_logits = t5_outputs[0][:, -1, :]
        logits = self.dropout(next_token_logits)
        logits = self.classifier(logits)

        outputs = (logits,)

        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()  # type:ignore
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs  # type:ignore

        return outputs
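Since rank calls T5Model.forward(self, ...) unbound, it evidently lives on a subclass of T5Model that adds the dropout, classifier, and num_labels attributes it uses. A hypothetical skeleton consistent with the method:

import torch.nn as nn
from transformers import T5Model

class T5Ranker(T5Model):  # hypothetical name
    def __init__(self, config, num_labels=2):
        super().__init__(config)
        self.num_labels = num_labels
        self.dropout = nn.Dropout(0.1)  # hypothetical rate
        self.classifier = nn.Linear(config.d_model, num_labels)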
Example #14
    def __init__(self,
                 model_name_or_path: str,
                 max_seq_length: int = 128,
                 do_lower_case: Optional[bool] = None,
                 task_identifier: str = 'stsb sentence1: ',
                 model_args: Dict = {},
                 tokenizer_args: Dict = {}):
        super(T5, self).__init__()
        self.config_keys = [
            'max_seq_length', 'do_lower_case', 'task_identifier'
        ]
        self.do_lower_case = do_lower_case

        if max_seq_length > 512:
            logging.warning(
                "T5 only allows a max_seq_length of 512. Value will be set to 512"
            )
            max_seq_length = 512
        self.max_seq_length = max_seq_length

        if self.do_lower_case is not None:
            tokenizer_args['do_lower_case'] = do_lower_case

        self.t5model = T5Model.from_pretrained(model_name_or_path,
                                               **model_args)
        self.tokenizer = T5Tokenizer.from_pretrained(model_name_or_path,
                                                     **tokenizer_args)
        self.task_identifier = task_identifier
Example #15
    def prepare_model(self, condition_generation=False):
        if condition_generation:
            self.model = T5ForConditionalGeneration.from_pretrained('t5-base')
        else:
            t5_model = T5Model.from_pretrained('t5-base')
            self.model = GenerationModel(t5_model)
        self.load_checkpoint()
        self.model = self.model.cuda()
Example #16
    def __init__(self):

        super().__init__()

        self.t5 = t5 = T5Model.from_pretrained('t5-small')

        self.out = nn.Linear(t5.config.to_dict()['d_model'],
                             t5.config.to_dict()['vocab_size'])
Example #17
    def create_and_check_t5_model_fp16_forward(
        self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
    ):
        model = T5Model(config=config)
        model.to(torch_device)
        model.half()
        model.eval()
        output = model(input_ids, decoder_input_ids=input_ids, attention_mask=attention_mask)[0]
        self.parent.assertFalse(torch.isnan(output).any().item())
Example #18
    def get_model(self) -> Union[T5Model, T5EncoderModel]:
        if not self._decoder:
            model = T5EncoderModel.from_pretrained(self._model_directory)
        else:
            model = T5Model.from_pretrained(self._model_directory)
        # Compute in half precision, saving us half the memory
        if self._half_precision_model:
            model = model.half()
        return model
Example #19
    def __init__(self, model, num_steps, num_classes=2):
        super(T5Classifier, self).__init__()
        hidden_size = {
            "t5-small": 512,
            "t5-base": 768,
            "t5-large": 1024,
        }[model]
        self.model = T5Model.from_pretrained(model)
        self.tokenizer = T5Tokenizer.from_pretrained(model)
        self.num_steps = num_steps
        self.classifier = nn.Linear(hidden_size, num_classes)
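The hard-coded sizes mirror config.d_model, so the lookup table can be avoided by reading the loaded config instead (a sketch with the same checkpoint names):

from transformers import T5Model

model = T5Model.from_pretrained('t5-base')
hidden_size = model.config.d_model  # 768 for t5-base, matching the table above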
Example #20
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = T5Config.from_json_file(config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = T5Model(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_t5(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)
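A hypothetical invocation, with placeholder paths standing in for a real TF checkpoint:

convert_tf_checkpoint_to_pytorch(
    tf_checkpoint_path='/path/to/t5/model.ckpt',
    config_file='/path/to/t5/config.json',
    pytorch_dump_path='/path/to/out/pytorch_model.bin',
)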
Example #21
def main():
    tokenizer = T5Tokenizer.from_pretrained('t5-small')
    model = T5Model.from_pretrained('t5-small')

    input_ids = tokenizer.encode("translate English to German: That is good.",
                                 return_tensors="pt")
    # T5Model has no LM head and needs explicit decoder inputs; reuse the encoder ids
    outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
    scores = outputs[0]  # decoder hidden states, not vocabulary logits

    out_indices = torch.argmax(scores, dim=2)
    predicted_token = tokenizer.convert_ids_to_tokens(out_indices[0])
    print(predicted_token)
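Because T5Model has no language-model head, the argmax over hidden states above does not produce a real translation. For actual predictions, the conditional-generation class is the usual route (a sketch with the same checkpoint):

from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained('t5-small')
lm = T5ForConditionalGeneration.from_pretrained('t5-small')
input_ids = tokenizer.encode("translate English to German: That is good.", return_tensors="pt")
print(tokenizer.decode(lm.generate(input_ids)[0], skip_special_tokens=True))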
Example #22
    def create_and_check_decoder_model_attention_mask_past(
        self,
        config,
        input_ids,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        lm_labels,
    ):
        model = T5Model(config=config).get_decoder()
        model.to(torch_device)
        model.eval()

        # create attention mask
        attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)

        half_seq_length = input_ids.shape[-1] // 2
        attn_mask[:, half_seq_length:] = 0

        # first forward pass
        output, past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True).to_tuple()

        # create a hypothetical next token and extend next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

        # change a random masked slice from input_ids
        random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
        random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
        input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens

        # append to next input_ids and attn_mask
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
        attn_mask = torch.cat(
            [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
            dim=1,
        )

        # get two different outputs
        output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
        output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[
            "last_hidden_state"
        ]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx].detach()
        output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()

        # test that outputs are equal for slice
        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
Example #23
    def prepare_model(self,
                      condition_generation=False,
                      template_decoding=False):
        print('condition_generation: ', condition_generation)
        if condition_generation:
            self.model = T5ForConditionalGeneration.from_pretrained('t5-base')
        else:
            t5_model = T5Model.from_pretrained('t5-base')
            if template_decoding:
                self.model = GenerationModel(t5_model, self.temp)
            else:
                self.model = GenerationModel(t5_model)
        self.lr = 1e-3
        self.model = self.model.cuda()
Example #24
    def __init__(self, config, x_embed):
        super().__init__()

        # pretrained_weights = "xlnet-base-cased"
        self.model = T5Model.from_pretrained(config.pretrained_weights)

        # if config.use_gpu:
        #   self.model = self.model.to(device=torch.device("cuda"))
        # if config.use_parallel:
        #   self.model = torch.nn.DataParallel(self.model)

        self.encoder_out_size = self.model.config.d_model  # 1024 for t5-large

        return
Example #25
    def __init__(self,
                 model_name_or_path: str,
                 max_seq_length: int = 128,
                 do_lower_case: bool = True):
        super(T5, self).__init__()
        self.config_keys = ['max_seq_length', 'do_lower_case']
        self.do_lower_case = do_lower_case

        if max_seq_length > 512:
            logging.warning(
                "T5 only allows a max_seq_length of 512. Value will be set to 512"
            )
            max_seq_length = 512
        self.max_seq_length = max_seq_length

        self.enc_model = T5Model.from_pretrained(model_name_or_path)
        self.tokenizer = T5Tokenizer.from_pretrained(
            model_name_or_path, do_lower_case=do_lower_case)
Example #26
    def __init__(self, **kwargs):
        """
        Initialize T5 embedder.

        :param model_directory: directory holding the pretrained T5 checkpoint
        """
        super().__init__(**kwargs)

        self._model_directory = self._options["model_directory"]
        # Until we know whether we need the decoder, let's keep it here as an undocumented option.
        # Should the need arise we can just split this class into an encoder and a decoder subclass
        # by setting one subclass to _decoder=True and the other to _decoder=False
        self._decoder = self._options.get("decoder", False)

        # make model
        self._model = T5Model.from_pretrained(self._model_directory)
        self._model = self._model.eval().to(self._device)
        self._model_fallback = None
        self._tokenizer = T5Tokenizer(
            str(Path(self._model_directory).joinpath("spiece.model")),
            do_lower_case=False,
        )
Example #27
# CLS token will work as BOS token
# tokenizer.bos_token = tokenizer.cls_token
# SEP token will work as EOS token
# tokenizer.eos_token = tokenizer.sep_token

# load dataset
dataset = Task71Dataset("train", tokenizer=tokenizer)

collator_fn = Task71aCollatorFeatures(device='cpu')
loader = DataLoader(dataset, batch_size=options.batch_size,
                    drop_last=False, shuffle=True,
                    collate_fn=collator_fn)


# create model
encoder = T5Model.from_pretrained('t5-base')

# change config if you want
# encoder.config.output_hidden_states = True
model = T5ClassificationHead(encoder.encoder, encoder.config.hidden_size,
                             num_classes=2, drop=0.2)
if options.modelckpt is not None:
    state_dict = torch.load(options.modelckpt, map_location='cpu')
    model.load_state_dict(state_dict)

model.to(DEVICE)

res_dict = get_features(loader, model, DEVICE)
if not os.path.exists('./features_train/'):
    os.makedirs('./features_train')
pickle.dump(res_dict, open("./features_train/t5_features.pkl", "wb"))
Example #28
    def test_model_from_pretrained(self):
        for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = T5Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
            self.assertIsNotNone(model)
Example #29
    def test_model_from_pretrained(self):
        for model_name in T5_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
            model = T5Model.from_pretrained(model_name)
            self.assertIsNotNone(model)
Example #30
options = parser.parse_args()

# build transforms using the T5 tokenizer
tokenizer = T5Tokenizer.from_pretrained('t5-base')

# load dataset
test_dataset = Task71Dataset("dev", tokenizer=tokenizer)

collator_fn = Task71aCollatorTest(device='cpu')
test_loader = DataLoader(test_dataset,
                         batch_size=options.batch_size,
                         drop_last=False,
                         shuffle=True,
                         collate_fn=collator_fn)

# create model
model = T5Model.from_pretrained('t5-base')
model = T5ClassificationHead(model.encoder,
                             model.config.hidden_size,
                             num_classes=2,
                             drop=0.2,
                             act='none')

if options.modelckpt is not None:
    state_dict = torch.load(options.modelckpt, map_location='cpu')
    model.load_state_dict(state_dict)

model.to(DEVICE)

create_submition_file(options.outfolder, model, test_loader, DEVICE)