Example #1
def test(epoch, net, test_loader):
    # global best_acc
    net.eval()
    test_loss = Average()
    test_acc = Accuracy()

    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            outputs = net(inputs)

            loss = F.cross_entropy(outputs, targets)

            test_loss.update(loss.item(), inputs.size(0))
            test_acc.update(outputs, targets)

    return test_loss, test_acc
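The metric helpers used here (and again in Examples #6, #7 and #9) are defined elsewhere in their projects. A minimal sketch of what the update-style `Average` and `Accuracy` might look like, inferred from how the snippets call them rather than taken from the source code:

import torch


class Average:
    # Hypothetical running-average helper, inferred from update(value, n) usage.
    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        # accumulate a batch-mean value weighted by the batch size
        self.sum += value * n
        self.count += n

    @property
    def average(self):
        return self.sum / max(self.count, 1)

    def __str__(self):
        return "{:.4f}".format(self.average)


class Accuracy:
    # Hypothetical top-1 accuracy over logits, inferred from update(outputs, targets) usage.
    def __init__(self):
        self.correct = 0
        self.count = 0

    def update(self, outputs, targets):
        with torch.no_grad():
            preds = outputs.argmax(dim=1)
            self.correct += (preds == targets).sum().item()
            self.count += targets.size(0)

    @property
    def accuracy(self):
        return self.correct / max(self.count, 1)

    def __str__(self):
        return "{:.2f}%".format(100.0 * self.accuracy)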
Example #2
def main(config, difficulty, type):
    logger = config.get_logger('train')

    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')
    batch_size = config['data_loader']['args']['batch_size']
    tgt_preprocessing = None
    src_preprocessing = None
    train_loader = Dataloader(
        device=device, difficulty=difficulty, type=type,
        src_preprocessing=src_preprocessing, tgt_preprocessing=tgt_preprocessing,
        batch_size=batch_size)

    valid_loader = Dataloader(
        device=device, difficulty=difficulty, type=type,
        src_preprocessing=src_preprocessing, tgt_preprocessing=tgt_preprocessing,
        batch_size=batch_size, train=False)
    
    model_args = config['arch']['args']
    model_args.update({
        'src_vocab': train_loader.src_vocab,
        'tgt_vocab': train_loader.tgt_vocab,
        'sos_tok': SOS_TOK,
        'eos_tok': EOS_TOK,
        'pad_tok': PAD_TOK,
        'device': device
    })
    model = getattr(models, config['arch']['type'])(**model_args)
    weight = torch.ones(len(train_loader.tgt_vocab))
    criterion = AvgPerplexity(
        ignore_idx=train_loader.tgt_vocab.stoi[PAD_TOK],
        weight=weight)

    criterion.to(device)

    optimizer = get_optimizer(
        optimizer_params=filter(
            lambda p: p.requires_grad, model.parameters()),
        args_dict=config['optimizer'])

    metrics_ftns = [Accuracy(
        train_loader.tgt_vocab.stoi[PAD_TOK])]
    # for param in model.parameters():
    #     param.data.uniform_(-0.08, 0.08)
    trainer = Trainer(
        model=model,
        criterion=criterion,
        metric_ftns=metrics_ftns,
        optimizer=optimizer,
        config=config,
        data_loader=train_loader,
        valid_data_loader=valid_loader,
        log_step=1, len_epoch=200
    )
    trainer.train()
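`get_optimizer` is a project-specific helper that is not shown in this snippet. A plausible minimal version, assuming the config entry follows the common `{'type': ..., 'args': {...}}` layout (both the key names and the structure are assumptions):

import torch


def get_optimizer(optimizer_params, args_dict):
    # Hypothetical helper: args_dict is assumed to look like
    # {'type': 'Adam', 'args': {'lr': 1e-3, 'weight_decay': 0}}.
    optim_cls = getattr(torch.optim, args_dict['type'])
    return optim_cls(optimizer_params, **args_dict.get('args', {}))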
Example #3
    def __init__(self, model, device='cpu', verbose=True):
        loss = nn.CrossEntropyLoss()
        loss.__name__ = 'cross_entropy'
        super().__init__(
            model=model,
            loss=loss,
            metrics=[Accuracy()],
            stage_name='valid',
            device=device,
            verbose=verbose,
        )
Example #4
def validate(dataiter, model, word_dict, sememe_dict, K=32):
    model.eval()
    with torch.no_grad():
        # pred_labels = []
        # target_labels = []
        dev_metric = Accuracy()

        for i, batch in enumerate(dataiter):
            word = batch["word"].to(device)
            sememes = batch["sememes"].to(device)
            dev_metric += model.compute_metric(word, sememes, K)
    return dev_metric
Example #5
    def compute_metric(self, word, sememes, k):
        bsize = sememes.size()[0]
        scores = self.forward(word)
        sorted_scores, sorted_indices = scores.sort(dim=-1, descending=True)
        correct = 0.0
        total = 0.0
        for index, sememe in zip(sorted_indices, sememes):
            total += (sememe != self.pad_idx).long().sum()
            index = set(index.tolist()[:k])
            sememe = set(sememe.tolist())
            correct += len(index.intersection(sememe))
        # return Accuracy(k * bsize, correct)
        return Accuracy(total.item(), correct)
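Examples #4, #5 and #8 use a different, count-based `Accuracy`: it is constructed from `(total, correct)`, accumulated with `+=`, compared with `>` and read out via `precision()`. A minimal sketch consistent with that usage, again inferred rather than copied from the source project:

class Accuracy:
    # Hypothetical count-based metric, inferred from how the examples
    # construct, add and compare it.
    def __init__(self, total=0, correct=0.0):
        self.total = total
        self.correct = correct

    def precision(self):
        return self.correct / self.total if self.total else 0.0

    def __add__(self, other):
        # '+=' accumulates counts across batches
        return Accuracy(self.total + other.total, self.correct + other.correct)

    def __gt__(self, other):
        return self.precision() > other.precision()

    def __repr__(self):
        return "{:.4f} ({}/{})".format(self.precision(), int(self.correct), self.total)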
Example #6
    def one_batch(self, is_dist=dist_is_initialized()):
        batch_ins, batch_label = self.get_batch()
        batch_grad = torch.zeros([self.n_input, 1],
                                 dtype=torch.float32,
                                 requires_grad=False)

        train_loss = Average()
        train_acc = Accuracy()

        z = self.batch_forward(batch_ins)
        h = self.batch_sigmoid(z)
        batch_label = torch.tensor(batch_label).float()
        loss = self.batch_loss(h, batch_label)
        train_loss.update(torch.mean(loss).item(), self.batch_size)
        train_acc.batch_update(h, batch_label)
        grad_list = self.batch_backward(batch_ins, h, batch_label)
        for g in grad_list:
            batch_grad.add_(g)
        batch_grad = batch_grad.div(self.batch_size)
        batch_grad.mul_(-1.0 * self.lr)
        self.weight.add_(batch_grad)

        # for i in range(len(batch_ins)):
        #     z = self.forward(batch_ins[i])
        #     h = self.sigmoid(z)
        #     loss = self.loss(h, batch_label[i])
        #     # print("z= {}, h= {}, loss = {}".format(z, h, loss))
        #     train_loss.update(loss.data, 1)
        #     train_acc.update(h, batch_label[i])
        #     g = self.backward(batch_ins[i], h.item(), batch_label[i])
        #     batch_grad.add_(g)
        #     batch_bias += np.sum(h.item() - batch_label[i])
        # batch_grad = batch_grad.div(self.batch_size)
        # batch_grad.mul_(-1.0 * self.lr)
        # self.weight.add_(batch_grad)
        # batch_bias = batch_bias / (len(batch_ins))
        # self.bias = self.bias - batch_bias * self.lr
        return train_loss, train_acc
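`one_batch` relies on helpers of its owning class (`get_batch`, `batch_forward`, `batch_sigmoid`, `batch_loss`, `batch_backward`) that are not part of the snippet. A minimal sketch of the math these helpers are assumed to implement for a plain logistic regression, assuming `batch_ins` is a dense tensor of shape `[batch_size, n_input]` (the class name and signatures below are hypothetical):

import torch


class LogisticRegressionSketch:
    # Hypothetical owner of one_batch(); only the helpers used above are sketched.
    def __init__(self, n_input, lr=0.01, batch_size=32):
        self.n_input = n_input
        self.lr = lr
        self.batch_size = batch_size
        self.weight = torch.zeros([n_input, 1], dtype=torch.float32)

    def batch_forward(self, batch_ins):
        # z = X @ w for a dense batch of shape [batch_size, n_input]
        return batch_ins.matmul(self.weight).squeeze(1)

    def batch_sigmoid(self, z):
        return torch.sigmoid(z)

    def batch_loss(self, h, y):
        # element-wise binary cross-entropy; the caller takes the mean
        eps = 1e-7
        return -(y * torch.log(h + eps) + (1.0 - y) * torch.log(1.0 - h + eps))

    def batch_backward(self, batch_ins, h, y):
        # per-sample gradient of the BCE loss w.r.t. the weights: x * (h - y),
        # returned as a list of [n_input, 1] tensors so the caller can sum them
        err = (h - y).unsqueeze(1)                          # [batch_size, 1]
        grads = batch_ins.unsqueeze(2) * err.unsqueeze(1)   # [batch_size, n_input, 1]
        return [g for g in grads]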
Example #7
def train_one_epoch(epoch, net, train_loader, optimizer, worker_index,
                    communicator, optim, sync_mode):
    assert isinstance(communicator, S3Communicator)
    net.train()

    epoch_start = time.time()

    epoch_cal_time = 0
    epoch_comm_time = 0

    train_acc = Accuracy()
    train_loss = Average()

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        batch_start = time.time()
        outputs = net(inputs)
        loss = F.cross_entropy(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        batch_cal_time = time.time() - batch_start
        batch_comm_time = 0

        if optim == "grad_avg":
            if sync_mode == "reduce" or sync_mode == "reduce_scatter":
                grads = [param.grad.data.numpy() for param in net.parameters()]
                batch_cal_time = time.time() - batch_start
                epoch_cal_time += batch_cal_time

                batch_comm_start = time.time()
                postfix = "{}_{}".format(epoch, batch_idx)
                if sync_mode in ("reduce", "reduce_scatter"):
                    merged_grads = communicator.reduce_batch_nn(pickle.dumps(grads), postfix)

                for layer_index, param in enumerate(net.parameters()):
                    param.grad.data = torch.from_numpy(merged_grads[layer_index])

                batch_comm_time = time.time() - batch_comm_start
                print("one {} round cost {} s".format(sync_mode, batch_comm_time))
                epoch_comm_time += batch_comm_time
            elif sync_mode == "async":
                # async does step before sync
                optimizer.step()
                batch_cal_time = time.time() - batch_start
                epoch_cal_time += batch_cal_time

                batch_comm_start = time.time()
                weights = [param.data.numpy() for param in net.parameters()]
                new_weights = communicator.async_reduce_nn(pickle.dumps(weights), Prefix.w_b_prefix)

                for layer_index, param in enumerate(net.parameters()):
                    param.data = torch.from_numpy(new_weights[layer_index])

                batch_comm_time = time.time() - batch_comm_start
                print("one {} round cost {} s".format(sync_mode, batch_comm_time))
                epoch_comm_time += batch_comm_time

        # async does step before sync
        if sync_mode != "async":
            step_start = time.time()
            optimizer.step()
            batch_cal_time += time.time() - step_start
            epoch_cal_time += batch_cal_time

        train_acc.update(outputs, targets)
        train_loss.update(loss.item(), inputs.size(0))

        if batch_idx % 10 == 0:
            print("Epoch: [{}], Batch: [{}], train loss: {}, train acc: {}, batch cost {} s, "
                  "cal cost {} s, comm cost {} s"
                  .format(epoch + 1, batch_idx + 1, train_loss, train_acc, time.time() - batch_start,
                          batch_cal_time, batch_comm_time))

    if optim == "model_avg":
        weights = [param.data.numpy() for param in net.parameters()]
        epoch_cal_time += time.time() - epoch_start

        epoch_sync_start = time.time()
        postfix = str(epoch)

        if sync_mode in ("reduce", "reduce_scatter"):
            merged_weights = communicator.reduce_epoch_nn(pickle.dumps(weights), postfix)
        elif sync_mode == "async":
            merged_weights = communicator.async_reduce_nn(pickle.dumps(weights), Prefix.w_b_prefix)

        for layer_index, param in enumerate(net.parameters()):
            param.data = torch.from_numpy(merged_weights[layer_index])

        print("one {} round cost {} s".format(sync_mode, time.time() - epoch_sync_start))
        epoch_comm_time += time.time() - epoch_sync_start

    if worker_index == 0:
        delete_start = time.time()
        # model avg delete by epoch
        if optim == "model_avg" and sync_mode != "async":
            communicator.delete_expired_epoch(epoch)
        elif optim == "grad_avg" and sync_mode != "async":
            communicator.delete_expired_batch(epoch, batch_idx)
        epoch_comm_time += time.time() - delete_start

    print("Epoch {} has {} batches, cost {} s, cal time = {} s, comm time = {} s"
          .format(epoch + 1, batch_idx, time.time() - epoch_start, epoch_cal_time, epoch_comm_time))

    return train_loss, train_acc
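`communicator.reduce_batch_nn` and `reduce_epoch_nn` belong to the project's `S3Communicator` and are not shown here. Conceptually they gather the pickled per-layer arrays from every worker and return their element-wise average; a stand-alone sketch of that averaging step (not the actual S3Communicator API):

import pickle


def average_gradients(pickled_grads_from_workers):
    # Conceptual sketch only. Each element is a pickled list of per-layer numpy
    # arrays from one worker; the result is the layer-wise average across workers.
    all_grads = [pickle.loads(blob) for blob in pickled_grads_from_workers]
    n_workers = len(all_grads)
    return [sum(layer) / n_workers for layer in zip(*all_grads)]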
Example #8
def main(args):

    ########### Build Dictionary  ###############
    sememe_dict = build_sememe_dict(args.data_path)
    print("Total Sememes : {}".format(len(sememe_dict)))
    args.latent_M = len(sememe_dict)
    args.sememe_dict = sememe_dict

    word_dict = build_word_dict(args.embed_path)
    print("Total Words : {}".format(len(word_dict)))
    ############## Build Dataset ####################
    trainset = SememePredictorDataset(word_dict, sememe_dict, "train", args)
    validset = SememePredictorDataset(word_dict, sememe_dict, "valid", args)
    testset = SememePredictorDataset(word_dict, sememe_dict, "test", args)

    print("| Train: {}".format(len(trainset)))
    print("| Valid: {}".format(len(validset)))
    print("| Test: {}".format(len(testset)))

    train_iter = DataLoader(
        trainset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=12,
        collate_fn=SememePredictorDataset.collate_fn,
    )
    valid_iter = DataLoader(
        validset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=12,
        collate_fn=SememePredictorDataset.collate_fn,
    )

    ############### Load Pretrained Word Embedding ##########
    # args.sememe_embedding = load_pretrained_embedding_from_file(args.sememe_embed_path,sememe_dict,args.embed_dim).to(device)
    args.pretrained_word_embedding, embed_dict = load_pretrained_embedding_from_file(
        args.embed_path, word_dict, args.embed_dim)
    args.pretrained_word_embedding.requires_grad = False
    # args.sememe_embedding.requires_grad = False
    sememe_encoder = SememeEncoder(sememe_dict, word_dict, args).to(device)
    sememe_embedding = sememe_encoder.get_all_sememe_repre().squeeze().detach()
    args.sememe_embedding = sememe_embedding
    ################# Build Model ##############
    model = SememePredictModel(args).to(device)
    # model = SememeFactorizationModel(args,sememe_dict.count).to(device)
    print("Model Built!")
    if args.mode == "analysis":
        model.load_state_dict(torch.load(args.save_path))
        analyze_sememe(
            valid_iter,
            model,
            word_dict,
            sememe_dict,
            embed_dict,
            "sememe_nearest_neighbour.txt",
        )
        analyze_word(
            valid_iter,
            model,
            word_dict,
            sememe_dict,
            embed_dict,
            "word_nearest_neighbour.txt",
        )
        exit()
    ############ Training #################
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                              mode="max",
                                                              factor=0.5,
                                                              patience=5)
    patience = 0
    best_metric = Accuracy()
    ################ Start Training ###########
    for epoch in range(1, args.epoch + 1):
        recall_at_8 = Accuracy()
        recall_at_16 = Accuracy()
        recall_at_32 = Accuracy()
        recall_at_64 = Accuracy()
        recall_at_128 = Accuracy()
        pbar = tqdm(train_iter, dynamic_ncols=True)
        pbar.set_description("[Epoch {}, Best Metric {}]".format(
            epoch, best_metric))
        model.train()
        for batch in pbar:
            word = batch["word"].to(device)
            sememes = batch["sememes"].to(device)
            # logits = model(word)
            loss = model.compute_loss(word, sememes)
            recall_at_8 += model.compute_metric(word, sememes, k=8)
            recall_at_16 += model.compute_metric(word, sememes, k=16)
            recall_at_32 += model.compute_metric(word, sememes, k=32)
            recall_at_64 += model.compute_metric(word, sememes, k=64)
            recall_at_128 += model.compute_metric(word, sememes, k=128)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.set_postfix(
                loss=loss.item(),
                r_at_8=recall_at_8,
                r_at_16=recall_at_16,
                r_at_32=recall_at_32,
                r_at_64=recall_at_64,
                r_at_128=recall_at_128,
            )

        dev_metric = validate(valid_iter, model, word_dict, sememe_dict, 32)
        lr_scheduler.step(dev_metric.precision())
        if dev_metric > best_metric:
            best_metric = dev_metric
            print("New Best Metric: {}".format(dev_metric))
            print(' R@8: {}, R@16: {}, R@32: {}, R@64: {}, R@128:{}'.format(
                validate(valid_iter, model, word_dict, sememe_dict, 8),
                validate(valid_iter, model, word_dict, sememe_dict, 16),
                validate(valid_iter, model, word_dict, sememe_dict, 32),
                validate(valid_iter, model, word_dict, sememe_dict, 64),
                validate(valid_iter, model, word_dict, sememe_dict, 128),
            ))
            torch.save(model.state_dict(), args.save_path)
            with open(args.save_path + '.sememe_vector', 'w') as f:
                for sym, vec in zip(sememe_dict.symbols, model.fc.weight.data):
                    f.write(sym + ' ' + ' '.join(map(str, vec.tolist())) +
                            '\n')
            patience = 0
        else:
            patience += 1
        if patience > args.patience:
            break

    model.load_state_dict(torch.load(args.save_path))
    test_iter = DataLoader(
        testset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=12,
        collate_fn=SememePredictorDataset.collate_fn,
    )
    test_metric = validate(test_iter, model, word_dict, sememe_dict)
    print("Test FScore: {}".format(test_metric))
Example #9
def handler(event, context):
    start_time = time.time()

    # dataset setting
    train_file = event['train_file']
    test_file = event['test_file']
    data_bucket = event['data_bucket']
    n_features = event['n_features']
    n_classes = event['n_classes']
    n_workers = event['n_workers']
    worker_index = event['worker_index']
    cp_bucket = event['cp_bucket']

    # ps setting
    host = event['host']
    port = event['port']

    # training setting
    model_name = event['model']
    optim = event['optim']
    sync_mode = event['sync_mode']
    assert model_name.lower() in MLModel.Deep_Models
    assert optim.lower() in Optimization.Grad_Avg
    assert sync_mode.lower() in Synchronization.Reduce

    # hyper-parameter
    learning_rate = event['lr']
    batch_size = event['batch_size']
    n_epochs = event['n_epochs']
    start_epoch = event['start_epoch']
    run_epochs = event['run_epochs']

    function_name = event['function_name']

    print('data bucket = {}'.format(data_bucket))
    print("train file = {}".format(train_file))
    print("test file = {}".format(test_file))
    print('number of workers = {}'.format(n_workers))
    print('worker index = {}'.format(worker_index))
    print('model = {}'.format(model_name))
    print('optimization = {}'.format(optim))
    print('sync mode = {}'.format(sync_mode))
    print('start epoch = {}'.format(start_epoch))
    print('run epochs = {}'.format(run_epochs))
    print('host = {}'.format(host))
    print('port = {}'.format(port))

    print("Run function {}, round: {}/{}, epoch: {}/{} to {}/{}".format(
        function_name,
        int(start_epoch / run_epochs) + 1, math.ceil(n_epochs / run_epochs),
        start_epoch + 1, n_epochs, start_epoch + run_epochs, n_epochs))

    # download file from s3
    storage = S3Storage()
    local_dir = "/tmp"
    read_start = time.time()
    storage.download(data_bucket, train_file,
                     os.path.join(local_dir, train_file))
    storage.download(data_bucket, test_file,
                     os.path.join(local_dir, test_file))
    print("download file from s3 cost {} s".format(time.time() - read_start))

    train_set = torch.load(os.path.join(local_dir, train_file))
    test_set = torch.load(os.path.join(local_dir, test_file))
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=batch_size,
                                               shuffle=True)
    n_train_batch = len(train_loader)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=100,
                                              shuffle=False)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse',
               'ship', 'truck')

    print("read data cost {} s".format(time.time() - read_start))

    random_seed = 100
    torch.manual_seed(random_seed)

    device = 'cpu'
    model = deep_models.get_models(model_name).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

    # load checkpoint model if it is not the first round
    if start_epoch != 0:
        checked_file = 'checkpoint_{}.pt'.format(start_epoch - 1)
        storage.download(cp_bucket, checked_file,
                         os.path.join(local_dir, checked_file))
        checkpoint_model = torch.load(os.path.join(local_dir, checked_file))

        model.load_state_dict(checkpoint_model['model_state_dict'])
        optimizer.load_state_dict(checkpoint_model['optimizer_state_dict'])
        print("load checkpoint model at epoch {}".format(start_epoch - 1))

    # Set thrift connection
    # Make socket
    transport = TSocket.TSocket(host, port)
    # Buffering is critical. Raw sockets are very slow
    transport = TTransport.TBufferedTransport(transport)
    # Wrap in a protocol
    protocol = TBinaryProtocol.TBinaryProtocol(transport)
    # Create a client to use the protocol encoder
    t_client = ParameterServer.Client(protocol)
    # Connect!
    transport.open()
    # test thrift connection
    ps_client.ping(t_client)
    print("create and ping thrift server >>> HOST = {}, PORT = {}".format(
        host, port))

    # register model
    parameter_shape = []
    parameter_length = []
    model_length = 0
    for param in model.parameters():
        tmp_shape = 1
        parameter_shape.append(param.data.numpy().shape)
        for w in param.data.numpy().shape:
            tmp_shape *= w
        parameter_length.append(tmp_shape)
        model_length += tmp_shape

    ps_client.register_model(t_client, worker_index, model_name, model_length,
                             n_workers)
    ps_client.exist_model(t_client, model_name)
    print("register and check model >>> name = {}, length = {}".format(
        model_name, model_length))

    # Training the Model
    train_start = time.time()
    iter_counter = 0
    for epoch in range(start_epoch, min(start_epoch + run_epochs, n_epochs)):

        model.train()
        epoch_start = time.time()

        train_acc = Accuracy()
        train_loss = Average()

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            batch_start = time.time()
            batch_cal_time = 0
            batch_comm_time = 0

            # pull latest model
            ps_client.can_pull(t_client, model_name, iter_counter,
                               worker_index)
            latest_model = ps_client.pull_model(t_client, model_name,
                                                iter_counter, worker_index)
            pos = 0
            for layer_index, param in enumerate(model.parameters()):
                param.data = Variable(
                    torch.from_numpy(
                        np.asarray(latest_model[pos:pos +
                                                parameter_length[layer_index]],
                                   dtype=np.float32).reshape(
                                       parameter_shape[layer_index])))
                pos += parameter_length[layer_index]
            batch_comm_time += time.time() - batch_start

            batch_cal_start = time.time()
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, targets)
            optimizer.zero_grad()
            loss.backward()

            # flatten and concat gradients of weight and bias
            param_grad = np.zeros((1))
            for param in model.parameters():
                # print("shape of layer = {}".format(param.data.numpy().flatten().shape))
                param_grad = np.concatenate(
                    (param_grad, param.grad.data.numpy().flatten()))
            param_grad = np.delete(param_grad, 0)
            #print("model_length = {}".format(param_grad.shape))
            batch_cal_time += time.time() - batch_cal_start

            # push gradient to PS
            batch_push_start = time.time()
            ps_client.can_push(t_client, model_name, iter_counter,
                               worker_index)
            ps_client.push_grad(t_client, model_name, param_grad,
                                -1. * learning_rate / n_workers, iter_counter,
                                worker_index)
            ps_client.can_pull(t_client, model_name, iter_counter + 1,
                               worker_index)  # sync all workers
            batch_comm_time += time.time() - batch_push_start

            train_acc.update(outputs, targets)
            train_loss.update(loss.item(), inputs.size(0))

            optimizer.step()
            iter_counter += 1

            if batch_idx % 10 == 0:
                print(
                    'Epoch: [%d/%d], Batch: [%d/%d], Time: %.4f, Loss: %.4f, epoch cost %.4f, '
                    'batch cost %.4f s: cal cost %.4f s and communication cost %.4f s'
                    % (epoch + 1, n_epochs, batch_idx + 1, n_train_batch,
                       time.time() - train_start, loss.item(),
                       time.time() - epoch_start, time.time() - batch_start,
                       batch_cal_time, batch_comm_time))

        test_loss, test_acc = test(epoch, model, test_loader)

        print(
            'Epoch: {}/{},'.format(epoch + 1, n_epochs),
            'train loss: {},'.format(train_loss),
            'train acc: {},'.format(train_acc),
            'test loss: {},'.format(test_loss),
            'test acc: {}.'.format(test_acc),
        )

    # training is not finished yet, invoke next round
    if epoch < n_epochs - 1:
        checkpoint_model = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': train_loss.average
        }

        checked_file = 'checkpoint_{}.pt'.format(epoch)

        if worker_index == 0:
            torch.save(checkpoint_model, os.path.join(local_dir, checked_file))
            storage.upload(cp_bucket, checked_file,
                           os.path.join(local_dir, checked_file))
            print("checkpoint model at epoch {} saved!".format(epoch))

        print(
            "Invoking the next round of functions. round: {}/{}, start epoch: {}, run epoch: {}"
            .format(
                int((epoch + 1) / run_epochs) + 1,
                math.ceil(n_epochs / run_epochs), epoch + 1, run_epochs))
        lambda_client = boto3.client('lambda')
        payload = {
            'train_file': event['train_file'],
            'test_file': event['test_file'],
            'data_bucket': event['data_bucket'],
            'n_features': event['n_features'],
            'n_classes': event['n_classes'],
            'n_workers': event['n_workers'],
            'worker_index': event['worker_index'],
            'cp_bucket': event['cp_bucket'],
            'host': event['host'],
            'port': event['port'],
            'model': event['model'],
            'optim': event['optim'],
            'sync_mode': event['sync_mode'],
            'lr': event['lr'],
            'batch_size': event['batch_size'],
            'n_epochs': event['n_epochs'],
            'start_epoch': epoch + 1,
            'run_epochs': event['run_epochs'],
            'function_name': event['function_name']
        }
        lambda_client.invoke(FunctionName=function_name,
                             InvocationType='Event',
                             Payload=json.dumps(payload))

    end_time = time.time()
    print("Elapsed time = {} s".format(end_time - start_time))
Example #10
    trainer_configs = {
        'output_folder':
        'C:\\Users\\thanhdh6\\Documents\\projects\\vinbrain_internship\\image_classifier\\train\\logs',
        'device': 'cpu',
        'gpu_id': 0,
        'lr_schedule': None,
        'config_files':
        'C:\\Users\\thanhdh6\\Documents\\projects\\vinbrain_internship\\image_classifier\\configs\\cifar_configs.py',
        'loss_file': "loss_file.txt",
        'n_crops': 5
    }

    net = TransferNet(model_base=resnet18,
                      pretrain=True,
                      fc_channels=[2048],
                      num_classes=2)
    optimizer_config = {'momentum': 0.9}
    optimizer = SGD(net.parameters(), lr=1e-4, momentum=0.9)
    metric = Accuracy(threshold=0.5, from_logits=True)
    criterion = nn.CrossEntropyLoss()
    trainer = Trainer(net,
                      datahandler,
                      optimizer,
                      criterion,
                      transform_test,
                      metric,
                      lr_scheduler=OneCycleLR,
                      configs=trainer_configs)
    trainer.train()
    # trainer.load_checkpoint("C:\\Users\\thanhdh6\\Documents\\projects\\vinbrain_internship\\image_classifier\\train\\logs\\checkpoint_9")
    # print(trainer.evaluate(mode="val",metric=Accuracy()))