def run(args):
    """Train a target CNN with adversarial domain adaptation (SVHN -> MNIST).

    Builds the SVHN (source) and MNIST (target) loaders, optionally restores
    a pre-trained source CNN from ``args.trained``, then trains the target
    encoder against a domain discriminator via ``train_target_cnn``.

    :param args: namespace providing logdir, trained, batch_size, n_workers,
        in_channels, device, lr, betas and weight_decay.
    """
    # exist_ok avoids the check-then-create race of the original
    # `if not exists: makedirs` pattern.
    os.makedirs(args.logdir, exist_ok=True)
    logger = get_logger(os.path.join(args.logdir, 'main.log'))
    logger.info(args)

    # data
    source_transform = transforms.Compose([
        transforms.ToTensor()]
    )
    # MNIST is single-channel; repeat it to 3 channels so the same CNN
    # can consume both domains.
    target_transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1))
    ])
    source_dataset_train = SVHN(
        './input', 'train', transform=source_transform, download=True)
    target_dataset_train = MNIST(
        './input', train=True, transform=target_transform, download=True)
    target_dataset_test = MNIST(
        './input', train=False, transform=target_transform, download=True)
    # drop_last keeps source/target batch sizes equal during adaptation.
    source_train_loader = DataLoader(
        source_dataset_train, args.batch_size, shuffle=True,
        drop_last=True,
        num_workers=args.n_workers)
    target_train_loader = DataLoader(
        target_dataset_train, args.batch_size, shuffle=True,
        drop_last=True,
        num_workers=args.n_workers)
    target_test_loader = DataLoader(
        target_dataset_test, args.batch_size, shuffle=False,
        num_workers=args.n_workers)

    # source CNN (optionally pre-trained)
    source_cnn = CNN(in_channels=args.in_channels).to(args.device)
    if os.path.isfile(args.trained):
        print("load model")
        c = torch.load(args.trained)
        source_cnn.load_state_dict(c['model'])
        logger.info('Loaded `{}`'.format(args.trained))
    else:
        print("not load model")

    # target CNN starts from the source weights; only its encoder is adapted.
    target_cnn = CNN(in_channels=args.in_channels, target=True).to(args.device)
    target_cnn.load_state_dict(source_cnn.state_dict())
    discriminator = Discriminator(args=args).to(args.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        target_cnn.encoder.parameters(),
        lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)
    d_optimizer = optim.Adam(
        discriminator.parameters(),
        lr=args.lr, betas=args.betas, weight_decay=args.weight_decay)
    train_target_cnn(
        source_cnn, target_cnn, discriminator,
        criterion, optimizer, d_optimizer,
        source_train_loader, target_train_loader, target_test_loader,
        args=args)
Exemplo n.º 2
0
class Predictor:
    """Single-image inference wrapper around a trained CNN checkpoint."""

    def __init__(self, ckpt_path):
        """Load the checkpoint at *ckpt_path* and prepare the model.

        The model is moved to GPU when one is available; inputs are moved
        accordingly in :meth:`run`.
        """
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(256),
            transforms.ToTensor()
        ])
        # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only
        # host (the original torch.load would fail there); .cuda() below
        # moves the weights back to the GPU when one exists.
        self.checkpoint = torch.load(ckpt_path, map_location='cpu')
        self.model = CNN()
        self.model.load_state_dict(state_dict=self.checkpoint['state_dict'])
        self.use_cuda = False
        self.model.eval()
        if torch.cuda.is_available():
            self.model.cuda()
            self.use_cuda = True

    def run(self, path):
        """Return the model's scalar prediction for the image at *path*."""
        img = self.transform(io.imread(path))
        img.unsqueeze_(0)  # add batch dimension
        if self.use_cuda:
            img = img.cuda()
        with torch.no_grad():
            pred = self.model(img)
        # .item() assumes the model emits a single-element tensor — TODO confirm
        return pred.item()
Exemplo n.º 3
0
def predict(img_dir='./data/test'):
    """Print predicted vs. true text for up to 11 captchas and show each image.

    :param img_dir: directory handed to CaptchaData
    """
    # Local name `tfm` avoids shadowing a torchvision `transforms` module.
    tfm = Compose([Resize(height, weight), ToTensor()])
    dataset = CaptchaData(img_dir, transform=tfm)
    use_cuda = torch.cuda.is_available()
    cnn = CNN()
    if use_cuda:
        cnn = cnn.cuda()
    # Load weights before flipping to eval mode; map to CPU when no GPU is
    # present (the original always used the default map_location and always
    # called .cuda() on tensors, crashing on CPU-only machines).
    cnn.load_state_dict(torch.load(
        model_path, map_location=None if use_cuda else 'cpu'))
    cnn.eval()

    for k, (img, target) in enumerate(dataset):
        img = img.view(1, 3, height, weight)
        target = target.view(1, 4 * 36)
        if use_cuda:
            img, target = img.cuda(), target.cuda()
        output = cnn(img)

        # one row per character position, 36 classes each
        output = output.view(-1, 36)
        target = target.view(-1, 36)
        output = nn.functional.softmax(output, dim=1)
        output = torch.argmax(output, dim=1)
        target = torch.argmax(target, dim=1)
        output = output.view(-1, 4)[0]
        target = target.view(-1, 4)[0]

        print('pred: ' + ''.join([alphabet[i] for i in output.cpu().numpy()]))
        print('true: ' + ''.join([alphabet[i] for i in target.cpu().numpy()]))

        plot.imshow(img.permute((0, 2, 3, 1))[0].cpu().numpy())
        plot.show()

        if k >= 10:
            break
Exemplo n.º 4
0
def post_predict():
    """Flask handler: decode a captcha image POSTed as raw bytes.

    Returns the 4-character prediction, or "failure" when the request is
    not a POST.
    """
    text = "failure"
    if request.method == 'POST':
        file = request.get_data()
        img = Image.open(BytesIO(file)).convert('RGB')
        print('宽:%d,高:%d' % (img.size[0], img.size[1]))
        width = img.size[0]
        height = img.size[1]
        transform = Compose([Resize(height, width), ToTensor()])
        img = transform(img)
        use_cuda = torch.cuda.is_available()
        # NOTE(review): building and loading the model on every request is
        # wasteful; kept here to preserve the original behaviour.
        cnn = CNN()
        if use_cuda:
            cnn = cnn.cuda()
        cnn.load_state_dict(torch.load(model_path, map_location='cpu'))
        cnn.eval()
        img = img.view(1, 3, height, width)
        # The original unconditionally called .cuda() on the input and
        # crashed on CPU-only hosts even though the model move was guarded.
        if use_cuda:
            img = img.cuda()
        output = cnn(img)
        output = output.view(-1, 36)  # one row per character, 36 classes
        output = nn.functional.softmax(output, dim=1)
        output = torch.argmax(output, dim=1)
        output = output.view(-1, 4)[0]
        text = ''.join([alphabet[i] for i in output.cpu().numpy()])

    return text
Exemplo n.º 5
0
class Tester:
    """Image-captioning test harness: CNN encoder + RNN decoder."""

    def __init__(self, _hparams):
        self.test_loader = get_test_loader(_hparams)
        self.encoder = CNN().to(DEVICE)
        self.decoder = RNN(fea_dim=_hparams.fea_dim,
                           embed_dim=_hparams.embed_dim,
                           hid_dim=_hparams.hid_dim,
                           max_sen_len=_hparams.max_sen_len,
                           vocab_pkl=_hparams.vocab_pkl).to(DEVICE)
        self.test_cap = _hparams.test_cap

    def testing(self, save_path, test_path):
        """Run the test set, dump generated captions as JSON and print
        COCO evaluation metrics.

        :param save_path: checkpoint file to restore model weights from
        :param test_path: output path for the generated sentences
        """
        print('*' * 20, 'test', '*' * 20)
        self.load_models(save_path)
        self.set_eval()

        sen_json = []
        with torch.no_grad():
            for val_step, (img, img_id) in tqdm(enumerate(self.test_loader)):
                img = img.to(DEVICE)
                # Call the module (not .forward) so registered hooks run.
                features = self.encoder(img)
                sens, _ = self.decoder.sample(features)
                sen_json.append({'image_id': int(img_id), 'caption': sens[0]})

        with open(test_path, 'w') as f:
            json.dump(sen_json, f)

        result = coco_eval(self.test_cap, test_path)
        for metric, score in result:
            print(metric, score)

    def load_models(self, save_path):
        """Restore encoder/decoder weights from *save_path*.

        The map_location remaps cuda:2 -> cuda:0 so a checkpoint saved on a
        different GPU can be loaded on this machine.
        """
        ckpt = torch.load(save_path,
                          map_location={'cuda:2': 'cuda:0'})
        encoder_state_dict = ckpt['encoder_state_dict']
        self.encoder.load_state_dict(encoder_state_dict)
        decoder_state_dict = ckpt['decoder_state_dict']
        self.decoder.load_state_dict(decoder_state_dict)

    def set_eval(self):
        """Switch both sub-models to inference mode."""
        self.encoder.eval()
        self.decoder.eval()
Exemplo n.º 6
0
def load_net():
    """Build the global captcha CNN and restore its weights.

    Uses module-level globals: alphabet, numchar, width, height, use_gpu
    and model_path. Weights are mapped to CPU when no GPU is configured.
    """
    global model_net
    model_net = CNN(num_class=len(alphabet),
                    num_char=int(numchar),
                    width=width,
                    height=height)
    if use_gpu:
        model_net = model_net.cuda()
        map_location = None  # keep tensors on their saved device
    else:
        map_location = 'cpu'
    # Load first, then switch to eval mode — the original duplicated the
    # eval/load pair in both branches and ran eval before the load.
    model_net.load_state_dict(torch.load(model_path, map_location=map_location))
    model_net.eval()
Exemplo n.º 7
0
def main():
    """Parse CLI options, load the trained CNN and run the FGSM evaluation."""
    parser = argparse.ArgumentParser(description="FGSM")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--data_root", type=str, default="./data")
    parser.add_argument("--data_name", type=str, default="mnist")
    parser.add_argument("--image_size", type=int, default=32)
    parser.add_argument("--image_channels", type=int, default=1)
    # epsilon is a fractional perturbation budget, so it must parse as a
    # float: the original `type=int` made `--epsilon 0.1` fail on the CLI
    # (argparse only applies `type` to command-line strings, not defaults).
    parser.add_argument("--epsilon", type=float, default=0.1)
    opt = parser.parse_args()

    model = CNN(opt.image_size, opt.image_channels).to(opt.device)
    model.load_state_dict(torch.load("./weights/cnn.pth"))

    test(model, opt)
Exemplo n.º 8
0
def init():
    """Initialise the global MNIST CNN for scoring on CPU.

    Falls back to a local 'model.pth' when the registered model cannot be
    resolved.
    """
    global model, device

    try:
        model_path = Model.get_model_path('pytorch_mnist')
    # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt are
    # not swallowed.
    except Exception:
        model_path = 'model.pth'

    device = torch.device('cpu')

    model = CNN()
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
Exemplo n.º 9
0
def init():
    """Initialise the global MNIST CNN for scoring on CPU and log startup.

    Falls back to a local 'model.pth' when the registered model cannot be
    resolved.
    """
    global model, device

    try:
        model_path = Model.get_model_path('pytorch_mnist')
    # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt are
    # not swallowed.
    except Exception:
        model_path = 'model.pth'

    device = torch.device('cpu')

    model = CNN()
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    print('Initialized model "{}" at {}'.format(model_path,
                                                datetime.datetime.now()))
Exemplo n.º 10
0
def predict(img_dir=Path('./captcha')):
    """Predict and print the text of every *.png captcha under *img_dir*."""
    tfm = Compose([ToTensor()])
    # Build and load the model once, outside the loop — the original
    # re-created and re-loaded it for every image.
    cnn = CNN()
    cnn.load_state_dict(torch.load(model_path))
    # Inference mode; the original left dropout/batch-norm in train mode.
    cnn.eval()

    for p in img_dir.glob("*png"):
        print(p.name)
        # Use the matched path itself; the original hard-coded './captcha/'
        # and silently ignored a non-default img_dir.
        img = img_loader(str(p))
        img = tfm(img)

        img = img.view(1, 3, 36, 120)
        with torch.no_grad():
            output = cnn(img)

        output = output.view(-1, 36)  # one row per character, 36 classes
        output = nn.functional.softmax(output, dim=1)
        output = torch.argmax(output, dim=1)
        output = output.view(-1, num_class)[0]

        pred = ''.join([alphabet[idx] for idx in output.cpu().numpy()])
        print(pred)
Exemplo n.º 11
0
                    args.window_dim, len(data["lbl2idx"]), args.dp, emb)

    if args.fix_emb:
        model.embedding.weight.requires_grad = False

    loss = torch.nn.CrossEntropyLoss()
    optim = torch.optim.Adam(model.parameters(),
                             lr=args.lr,
                             weight_decay=args.l2)

    if args.cuda:
        model.cuda()
    trainModel(args, model, loss, optim, trainData, valData)

    if args.submit:
        # load the best model saved during training
        model.load_state_dict(
            torch.load(args.path_savedir +
                       "{}_{}.model".format(args.model, args.epochs)))
        model.eval()

        preds_val = predict(model, valData)
        save_prediction(
            args.path_savedir + "{}_{}.val".format(args.model, args.epochs),
            preds_val, data["idx2lbl"])
        testData = Dataset(data["test"], args.batch_size, args.cuda)
        preds_test = predict(model, testData)
        save_prediction(
            args.path_savedir + "{}_{}.test".format(args.model, args.epochs),
            preds_test, data["idx2lbl"])
Exemplo n.º 12
0
def main(args):
    """Generate captions for all images under *args.image_path* using a
    checkpointed CNN encoder / RNN decoder and dump them as COCO-style JSON
    to 'captions_model.json'.

    The JSON dump runs in a ``finally`` block so partial results survive a
    Ctrl-C or an unexpected error.
    """
    # Hoisted out of the finally block: importing there could itself fail
    # and mask the original exception.
    import json

    # hyperparameters
    batch_size = args.batch_size
    num_workers = 2

    # Image Preprocessing (ImageNet normalisation constants)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    vocab = load_vocab()

    loader = get_basic_loader(dir_path=os.path.join(args.image_path),
                              transform=transform,
                              batch_size=batch_size,
                              shuffle=False,
                              num_workers=num_workers)

    # Build the models
    embed_size = args.embed_size
    num_hiddens = args.num_hidden
    checkpoint_path = 'checkpoints'

    encoder = CNN(embed_size)
    decoder = RNN(embed_size,
                  num_hiddens,
                  len(vocab),
                  1,
                  rec_unit=args.rec_unit)

    encoder_state_dict, decoder_state_dict, optimizer, *meta = utils.load_models(
        args.checkpoint_file)
    encoder.load_state_dict(encoder_state_dict)
    decoder.load_state_dict(decoder_state_dict)

    if torch.cuda.is_available():
        encoder.cuda()
        decoder.cuda()

    # Bound before the try so the finally block can always dump it, even if
    # an error occurs before the loop starts.
    results = []
    try:
        with torch.no_grad():
            for step, (images, image_ids) in enumerate(tqdm(loader)):
                images = utils.to_var(images)

                features = encoder(images)
                captions = beam_sample(decoder, features)
                captions = captions.cpu().data.numpy()
                captions = [
                    utils.convert_back_to_text(cap, vocab) for cap in captions
                ]
                captions_formatted = [{
                    'image_id': int(img_id),
                    'caption': cap
                } for img_id, cap in zip(image_ids, captions)]
                results.extend(captions_formatted)
                print('Sample:', captions_formatted[0])
    except KeyboardInterrupt:
        print('Ok bye!')
    finally:
        file_name = 'captions_model.json'
        with open(file_name, 'w') as f:
            json.dump(results, f)
def run(args):
    """Adversarial domain adaptation for one of two source/target settings.

    ``args.mode`` selects the pair:
      * 'm2mm' — MNIST -> MNIST-M
      * 's2u'  — SVHN  -> USPS
    A pre-trained source CNN is restored from
    ``args.trained + args.mode + '/best_model.pt'`` when that file exists,
    then the target encoder is trained against a domain discriminator.
    """
    args.logdir = args.logdir + args.mode
    args.trained = args.trained + args.mode + '/best_model.pt'
    # exist_ok avoids the check-then-create race of the original.
    os.makedirs(args.logdir, exist_ok=True)
    logger = get_logger(os.path.join(args.logdir, 'main.log'))
    logger.info(args)

    batch_size = 128
    if args.mode == 'm2mm':
        target_image_root = os.path.join('dataset', 'mnist_m')
        image_size = 28
        # MNIST mean/std; MNIST-M is RGB, hence 3-channel normalisation.
        img_transform_source = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.1307, ), std=(0.3081, ))
        ])

        img_transform_target = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        dataset_source = datasets.MNIST(root='dataset',
                                        train=True,
                                        transform=img_transform_source,
                                        download=True)

        train_list = os.path.join(target_image_root,
                                  'mnist_m_train_labels.txt')
        dataset_target_train = GetLoader(
            data_root=os.path.join(target_image_root, 'mnist_m_train'),
            data_list=train_list,
            transform=img_transform_target)

        test_list = os.path.join(target_image_root, 'mnist_m_test_labels.txt')
        dataset_target_test = GetLoader(
            data_root=os.path.join(target_image_root, 'mnist_m_test'),
            data_list=test_list,
            transform=img_transform_target)
    elif args.mode == 's2u':
        dataset_source = svhn.SVHN('./data/svhn/',
                                   split='train',
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.Resize(28),
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5))
                                   ]))

        # Same (stateless) transform for USPS train and test splits.
        usps_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, ), (0.5, ))
        ])
        dataset_target_train = usps.USPS('./data/usps/',
                                         train=True,
                                         download=True,
                                         transform=usps_transform)
        dataset_target_test = usps.USPS('./data/usps/',
                                        train=False,
                                        download=True,
                                        transform=usps_transform)
    else:
        # The original silently fell through here and later died with a
        # NameError on dataset_source; fail fast with a clear message.
        raise ValueError('unsupported mode: {!r}'.format(args.mode))

    source_train_loader = torch.utils.data.DataLoader(dataset=dataset_source,
                                                      batch_size=batch_size,
                                                      shuffle=True,
                                                      num_workers=8)

    target_train_loader = torch.utils.data.DataLoader(
        dataset=dataset_target_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8)

    target_test_loader = torch.utils.data.DataLoader(
        dataset=dataset_target_test,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8)

    # source CNN (optionally pre-trained)
    source_cnn = CNN(in_channels=args.in_channels).to(args.device)
    if os.path.isfile(args.trained):
        print("load model")
        c = torch.load(args.trained)
        source_cnn.load_state_dict(c['model'])
        logger.info('Loaded `{}`'.format(args.trained))
    else:
        print("not load model")

    # target CNN starts from the source weights; only its encoder is adapted.
    target_cnn = CNN(in_channels=args.in_channels, target=True).to(args.device)
    target_cnn.load_state_dict(source_cnn.state_dict())
    discriminator = Discriminator(args=args).to(args.device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(target_cnn.encoder.parameters(), lr=args.lr)
    d_optimizer = optim.Adam(discriminator.parameters(), lr=args.lr)
    train_target_cnn(source_cnn,
                     target_cnn,
                     discriminator,
                     criterion,
                     optimizer,
                     d_optimizer,
                     source_train_loader,
                     target_train_loader,
                     target_test_loader,
                     args=args)
Exemplo n.º 14
0
# CNN hyperparameters for the sentiment/text classifier.
N_FILTERS = 100
FILTER_SIZES = [3, 4, 5]
OUTPUT_DIM = 1
DROPOUT = 0.5

model = CNN(INPUT_DIM, EMBEDDING_DIM, N_FILTERS, FILTER_SIZES, OUTPUT_DIM,
            DROPOUT)

ck1 = time.time()
# Python 3 print function — the original used Python 2 print-statement
# syntax, which is a SyntaxError under Python 3.
print('1: %.3f' % (ck1 - start))

model.load_state_dict(torch.load('pretrained_model.pt'))

# (removed a leftover `pdb.set_trace()` debugging trap here)
model.eval()
all_preds = np.zeros(len(test_data))
all_ids = np.zeros(len(test_data))
all_labels = np.zeros(len(test_data))

with torch.no_grad():
    i = 0
    for batch in test_iterator:
        predictions = model(batch.text).squeeze(1)

        for j, pred in enumerate(predictions):
Exemplo n.º 15
0
# `import torch` is required: `import torch.nn as nn` alone binds only the
# name `nn`, so the original raised NameError on torch.cuda / torch.load.
import torch
import torch.nn as nn
from torchvision.transforms.functional import to_tensor
from torchvision.transforms import Compose, ToTensor
from datasets import CaptchaData
from models import CNN
from PIL import Image

# Captcha alphabet: digits 0-9 followed by lowercase a-z.
source = [str(i) for i in range(0, 10)]
source += [chr(i) for i in range(97, 97 + 26)]
model_path = './model.pth'

# Build the model once at import time; load weights before switching to
# eval mode, mapping to CPU when no GPU is present.
cnn = CNN()
use_cuda = torch.cuda.is_available()
if use_cuda:
    cnn = cnn.cuda()
cnn.load_state_dict(
    torch.load(model_path, map_location=None if use_cuda else 'cpu'))
cnn.eval()


# img_path:单张图片路径
def captchaByPath(img_path):
    img = Image.open(img_path)
    img = to_tensor(img)
    if torch.cuda.is_available():
        img = img.view(1, 3, 32, 120).cuda()
    else:
        img = img.view(1, 3, 32, 120)
    output = cnn(img)
    output = output.view(-1, 36)
Exemplo n.º 16
0
class Solver(object):
    """Training / evaluation / inference driver for a text-classification CNN.

    Wraps model construction (DataParallel on GPU), checkpointing, the
    training loop, accuracy evaluation and MLflow experiment logging.
    """

    def __init__(self, config, data_loader):
        self.config = config
        self.data_loader = data_loader

    def build(self, is_train):
        """Instantiate model, loss function and (for training) the optimizer.

        :param is_train: True -> train mode + optimizer; False -> eval mode,
            unwrapping DataParallel so loaded state dicts match the module.
        """
        if torch.cuda.is_available():
            self.model = nn.DataParallel(CNN(self.config)).cuda()
        else:
            self.model = CNN(self.config)

        self.loss_fn = self.config.loss_fn()

        if is_train:
            self.model.train()
            self.optimizer = self.config.optimizer(self.model.parameters(), lr=self.config.lr)
        else:
            if torch.cuda.is_available():
                self.model = self.model.module
            self.model.eval()

    def save(self, ckpt_path):
        """Save model parameters"""
        print('Save parameters at ', ckpt_path)

        # Save the unwrapped module so checkpoints are DataParallel-agnostic.
        if torch.cuda.is_available():
            torch.save(self.model.module.state_dict(), ckpt_path)
        else:
            torch.save(self.model.state_dict(), ckpt_path)

    def load(self, ckpt_path=None, epoch=None):
        """Load model parameters"""
        if not (ckpt_path or epoch):
            epoch = self.config.epochs  # default to the final epoch's checkpoint
        if epoch:
            ckpt_path = os.path.join(self.config.save_dir, f'epoch-{epoch}.pkl')
        print('Load parameters from ', ckpt_path)
        print(self.model)

        self.model.load_state_dict(torch.load(ckpt_path))

    def train_once(self):
        """Run one optimisation pass over the loader; return mean batch loss."""
        loss_history = []

        for batch_i, batch in enumerate(tqdm(self.data_loader)):
            # text: [max_seq_len, batch_size]; label: [batch_size]
            text, label = batch.text, batch.label

            if torch.cuda.is_available():
                text = text.cuda()
                label = label.cuda()

            # in-place transpose to [batch_size, max_seq_len]
            text.data.t_()

            # [batch_size, 2]
            logit = self.model(text)

            average_batch_loss = self.loss_fn(logit, label)
            loss_history.append(average_batch_loss.item())

            # Flush out remaining gradient, backprop, gradient descent.
            self.optimizer.zero_grad()
            average_batch_loss.backward()
            self.optimizer.step()

        return np.mean(loss_history)

    def train(self):
        """Train model with training data"""
        for epoch in tqdm(range(self.config.epochs)):
            # One epoch of optimisation; the original duplicated the whole
            # train_once loop body here verbatim.
            epoch_loss = self.train_once()

            # Log intermediate loss
            if (epoch + 1) % self.config.log_every_epoch == 0:
                log_str = f'Epoch {epoch + 1} | loss: {epoch_loss:.4f}\n'
                print(log_str)

            # Save model parameters
            if (epoch + 1) % self.config.save_every_epoch == 0:
                ckpt_path = os.path.join(self.config.save_dir, f'epoch-{epoch+1}.pkl')
                self.save(ckpt_path)

    def eval(self):
        """Evaluate model from text data; return (mean loss, accuracy)."""
        n_total_data = 0
        n_correct = 0
        loss_history = []

        for _, batch in enumerate(tqdm(self.data_loader)):
            # text: [max_seq_len, batch_size]; label: [batch_size]
            text, label = batch.text, batch.label

            if torch.cuda.is_available():
                text = text.cuda()
                label = label.cuda()

            # in-place transpose to [batch_size, max_seq_len]
            text.data.t_()

            # [batch_size, 2]
            logit = self.model(text)

            average_batch_loss = self.loss_fn(logit, label)
            loss_history.append(average_batch_loss.item())

            n_total_data += len(label)

            # argmax over the class dimension -> [batch_size]
            _, prediction = logit.max(1)
            n_correct += (prediction == label).sum().data

        epoch_loss = np.mean(loss_history)
        accuracy = n_correct.item() / float(n_total_data)

        print(f'Loss: {epoch_loss:.2f}')
        print(f'Accuracy: {accuracy}')

        return epoch_loss, accuracy

    def inference(self, text):
        """Classify one tokenised example; return the predicted class index."""
        text = Variable(torch.LongTensor([text]))

        # [batch_size, 2]
        logit = self.model(text)

        # Bug fix: one-arg torch.max(logit) returns a single 0-dim tensor,
        # so the original two-target unpacking raised. Take the argmax over
        # the class dimension, matching eval()'s logit.max(1).
        _, prediction = torch.max(logit, 1)

        return prediction

    def train_eval(self):
        """Train and validate each epoch, logging params/metrics to MLflow."""
        # Set this variable to your MLflow server's DNS name
        mlflow_server = '172.23.147.124'

        # Tracking URI
        mlflow_tracking_URI = 'http://' + mlflow_server + ':5000'
        print("MLflow Tracking URI: %s" % (mlflow_tracking_URI))

        with mlflow.start_run():
            for key, value in vars(self.config).items():
                mlflow.log_param(key, value)

            for epoch in tqdm(range(self.config.epochs)):
                # print out active_run
                print("Active Run ID: %s, Epoch: %s \n" % (mlflow.active_run(), epoch))

                train_loss = self.train_once()
                mlflow.log_metric('train_loss', train_loss)

                val_loss, val_acc = self.eval()
                mlflow.log_metric('val_loss', val_loss)
                mlflow.log_metric('val_acc', val_acc)

        # Finish run
        mlflow.end_run(status='FINISHED')
Exemplo n.º 17
0
def main(args):
    """Train (or sample from) a CNN-encoder / RNN-decoder captioning model on COCO.

    Resumes from ``args.checkpoint_file`` when given; checkpoints every
    ``save_step`` batches and always saves once more on exit (including Ctrl-C).
    """
    # hyperparameters
    batch_size = args.batch_size
    num_workers = 1

    # Image Preprocessing (ImageNet normalization statistics)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # load COCOs dataset
    IMAGES_PATH = 'data/train2014'
    CAPTION_FILE_PATH = 'data/annotations/captions_train2014.json'

    vocab = load_vocab()
    train_loader = get_coco_data_loader(path=IMAGES_PATH,
                                        json=CAPTION_FILE_PATH,
                                        vocab=vocab,
                                        transform=transform,
                                        batch_size=batch_size,
                                        shuffle=True,
                                        num_workers=num_workers)

    IMAGES_PATH = 'data/val2014'
    CAPTION_FILE_PATH = 'data/annotations/captions_val2014.json'
    val_loader = get_coco_data_loader(path=IMAGES_PATH,
                                      json=CAPTION_FILE_PATH,
                                      vocab=vocab,
                                      transform=transform,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      num_workers=num_workers)

    losses_val = []
    losses_train = []

    # Build the models
    ngpu = 1
    initial_step = initial_epoch = 0
    embed_size = args.embed_size
    num_hiddens = args.num_hidden
    learning_rate = 1e-3
    num_epochs = 3
    log_step = args.log_step
    save_step = 500
    checkpoint_dir = args.checkpoint_dir

    encoder = CNN(embed_size)
    decoder = RNN(embed_size,
                  num_hiddens,
                  len(vocab),
                  1,
                  rec_unit=args.rec_unit)

    # Loss
    criterion = nn.CrossEntropyLoss()

    if args.checkpoint_file:
        # Resume: restore weights, optimizer state and progress counters.
        encoder_state_dict, decoder_state_dict, optimizer, *meta = utils.load_models(
            args.checkpoint_file, args.sample)
        initial_step, initial_epoch, losses_train, losses_val = meta
        encoder.load_state_dict(encoder_state_dict)
        decoder.load_state_dict(decoder_state_dict)
    else:
        # Only the encoder's projection head is optimized; the CNN backbone
        # keeps its pretrained weights.
        params = list(decoder.parameters()) + list(
            encoder.linear.parameters()) + list(encoder.batchnorm.parameters())
        optimizer = torch.optim.Adam(params, lr=learning_rate)

    if torch.cuda.is_available():
        encoder.cuda()
        decoder.cuda()

    if args.sample:
        return utils.sample(encoder, decoder, vocab, val_loader)

    # Train the Models
    total_step = len(train_loader)
    # BUG FIX: pre-bind step/epoch so the `finally` save cannot raise
    # NameError when we are interrupted before the first batch completes.
    step, epoch = initial_step, initial_epoch
    try:
        for epoch in range(initial_epoch, num_epochs):

            for step, (images, captions,
                       lengths) in enumerate(train_loader, start=initial_step):

                # Set mini-batch dataset
                # BUG FIX: training inputs must not be volatile — a volatile
                # graph cannot be backpropagated through.
                images = utils.to_var(images)
                captions = utils.to_var(captions)
                targets = pack_padded_sequence(captions,
                                               lengths,
                                               batch_first=True)[0]

                # Forward, Backward and Optimize
                decoder.zero_grad()
                encoder.zero_grad()

                if ngpu > 1:
                    # run on multiple GPU
                    features = nn.parallel.data_parallel(
                        encoder, images, range(ngpu))
                    outputs = nn.parallel.data_parallel(
                        decoder, features, range(ngpu))
                else:
                    # run on single GPU
                    features = encoder(images)
                    outputs = decoder(features, captions, lengths)

                train_loss = criterion(outputs, targets)
                # BUG FIX: .item() instead of .data[0] — indexing a 0-dim
                # loss tensor raises IndexError in modern PyTorch.
                losses_train.append(train_loss.item())
                train_loss.backward()
                optimizer.step()

                # Run validation set and predict
                if step % log_step == 0:
                    encoder.batchnorm.eval()
                    # run validation set
                    batch_loss_val = []
                    for val_step, (images, captions,
                                   lengths) in enumerate(val_loader):
                        images = utils.to_var(images, volatile=True)
                        captions = utils.to_var(captions, volatile=True)

                        targets = pack_padded_sequence(captions,
                                                       lengths,
                                                       batch_first=True)[0]
                        features = encoder(images)
                        outputs = decoder(features, captions, lengths)
                        val_loss = criterion(outputs, targets)
                        # BUG FIX: .item() instead of .data[0] (see above).
                        batch_loss_val.append(val_loss.item())

                    losses_val.append(np.mean(batch_loss_val))

                    # predict a caption for the last validation image
                    sampled_ids = decoder.sample(features)
                    sampled_ids = sampled_ids.cpu().data.numpy()[0]
                    sentence = utils.convert_back_to_text(sampled_ids, vocab)
                    print('Sample:', sentence)

                    true_ids = captions.cpu().data.numpy()[0]
                    sentence = utils.convert_back_to_text(true_ids, vocab)
                    print('Target:', sentence)

                    print(
                        'Epoch: {} - Step: {} - Train Loss: {} - Eval Loss: {}'
                        .format(epoch, step, losses_train[-1], losses_val[-1]))
                    encoder.batchnorm.train()

                # Save the models
                if (step + 1) % save_step == 0:
                    utils.save_models(encoder, decoder, optimizer, step, epoch,
                                      losses_train, losses_val, checkpoint_dir)
                    utils.dump_losses(
                        losses_train, losses_val,
                        os.path.join(checkpoint_dir, 'losses.pkl'))

    except KeyboardInterrupt:
        pass
    finally:
        # Do final save
        utils.save_models(encoder, decoder, optimizer, step, epoch,
                          losses_train, losses_val, checkpoint_dir)
        utils.dump_losses(losses_train, losses_val,
                          os.path.join(checkpoint_dir, 'losses.pkl'))
Exemplo n.º 18
0
# Select the architecture to evaluate; the weight file must match it.
if MODEL == 'CNN':
    from models import CNN
    model = CNN()
elif MODEL == 'MLP':
    from models import MLP
    model = MLP()
else:
    raise NotImplementedError("You need to choose among [CNN, MLP].")

# Load the fully pickled model saved during training and copy its
# parameters (weights and biases) into the freshly built architecture.
trained_model = torch.load('{}.pt'.format(MODEL))
state_dict = trained_model.state_dict()
model.load_state_dict(state_dict)

# Count correct predictions over the whole evaluation set.
nb_correct_answers = 0
for data in data_loader:
    inputs, labels = data[0], data[1]
    # BUG FIX: the original compared the model *object* to the string 'MLP'
    # (always False); flatten to [batch, C*H*W] when evaluating the MLP.
    inputs = inputs.view(inputs.shape[0], -1) if MODEL == 'MLP' else inputs
    classification_results = model(inputs)
    # BUG FIX: argmax over the class dimension; a bare argmax() flattens the
    # whole batch and was only correct for batch size 1.
    nb_correct_answers += torch.eq(classification_results.argmax(dim=1),
                                   labels).sum()
# BUG FIX: divide by the number of samples, not the number of batches
# (identical when batch_size == 1).
# NOTE(review): assumes data_loader is a torch DataLoader — confirm.
print("Average acc.: {} %.".format(
    float(nb_correct_answers) / len(data_loader.dataset) * 100))
Exemplo n.º 19
0
    full_test_cm = get_cm(torch.Tensor(true_labels), torch.Tensor(predictions))
    msg = 'Finished Training. Test Accuracy : {:.2f} Mean Loss : {:.2f}'.format(
        (full_test_cm.diag().sum() / full_test_cm.sum()).item(),
        full_test_loss)
    print(msg)
    logging.info(msg)

    #save model
    if True:
        torch.save(
            {
                'epoch': epoch,
                'model_state_dict': network.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss,
            }, r"src/saved_models/model2.pkl")

        #load example
        checkpoint = torch.load(r"src/saved_models/model2.pkl")
        network.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']

    #analyze results
    activity_labels = pd.read_csv("input/LabelMap.csv", index_col=0)
    test_df = pd.DataFrame(full_test_cm.long().numpy(),
                           columns=activity_labels.Activity,
                           index=activity_labels.Activity)
    test_df.to_csv('src/ConfusionMatrixTest__lr_scheduler.csv')
Exemplo n.º 20
0
                args.num_filters) + "_" + str(args.batch_size) + "_" + str(
                    args.rate) + ".model"

    if args.gpu:
        model = model.cuda()

    # Training setup
    L = t.nn.CrossEntropyLoss()
    optimizer = t.optim.Adam(model.parameters(), lr=args.learn_rate)

    if not os.path.exists("models"):
        os.makedirs("models")

    # load to continue with pre-existing model
    if os.path.exists(fname):
        model.load_state_dict(t.load(fname))
        print("Successfully loaded previous model " + str(fname))

    # start with a model defined on 0
#    train_mix, test, train_data, train_labels = dataFetch()
#    # select only 0 category
#    train_dataset = customDataset(train_data[0], train_labels[0])
#
#    # define train and test as DataLoaders
#    train_loader = t.utils.data.DataLoader(dataset=train_dataset,
#                                           batch_size=args.batch_size,
#                                           shuffle=True)
#    test_loader = test

    train_data, test_data = dataFetch(dset=args.data)
Exemplo n.º 21
0
    return {
        'loss': losses.avg,
        'acc': acc,
    }


class ARG:
    """Minimal stand-in for the argparse namespace used by the evaluation code.

    Only carries the attributes the script below reads.
    """

    def __init__(self, device='cuda:0'):
        # Torch device string; defaults to the first CUDA GPU as before,
        # but can now be overridden (e.g. ARG('cpu')).
        self.device = device


if __name__ == "__main__":
    args = ARG()

    # Restore the adapted target classifier from its checkpoint.
    target_cnn = CNN(in_channels=3, target=True).to(args.device)
    ckpt = torch.load('outputs/garbage1/best_model.pt')
    target_cnn.load_state_dict(ckpt['model'])

    # Resize MNIST digits and tile the single channel to 3 channels so they
    # match the 3-channel input the model expects.
    target_transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
    ])
    criterion = nn.CrossEntropyLoss()

    target_dataset_test = MNIST(
        './input', train=False, transform=target_transform, download=True)
    target_test_loader = DataLoader(
        target_dataset_test, 128, shuffle=False, num_workers=4)
    print(len(target_test_loader))
Exemplo n.º 22
0
def train():
    """Train the captcha CNN, printing per-epoch train/test loss and accuracy.

    Relies on module-level configuration: batch_size, base_lr, max_epoch,
    model_path and restor. The model state dict is re-saved after every epoch.
    """
    transforms = Compose([ToTensor()])
    train_dataset = CaptchaData('./data/train', transform=transforms)
    train_data_loader = DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   num_workers=0,
                                   shuffle=True,
                                   drop_last=True)
    test_data = CaptchaData('./data/test', transform=transforms)
    test_data_loader = DataLoader(test_data,
                                  batch_size=batch_size,
                                  num_workers=0,
                                  shuffle=True,
                                  drop_last=True)
    cnn = CNN()
    if torch.cuda.is_available():
        cnn.cuda()
    if restor:
        # Resume from a previously saved checkpoint.
        cnn.load_state_dict(torch.load(model_path))

    optimizer = torch.optim.Adam(cnn.parameters(), lr=base_lr)
    criterion = nn.MultiLabelSoftMarginLoss()

    for epoch in range(max_epoch):
        start_ = time.time()

        loss_history = []
        acc_history = []
        cnn.train()
        for img, target in train_data_loader:
            img = Variable(img)
            target = Variable(target)
            if torch.cuda.is_available():
                img = img.cuda()
                target = target.cuda()
            output = cnn(img)
            loss = criterion(output, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            acc = calculat_acc(output, target)
            acc_history.append(acc)
            # BUG FIX: store the scalar value; appending the loss tensor kept
            # every batch's autograd graph alive for the whole epoch.
            loss_history.append(loss.item())
        print('train_loss: {:.4}|train_acc: {:.4}'.format(
            torch.mean(torch.Tensor(loss_history)),
            torch.mean(torch.Tensor(acc_history)),
        ))

        loss_history = []
        acc_history = []
        cnn.eval()
        with torch.no_grad():  # no gradients needed during evaluation
            for img, target in test_data_loader:
                img = Variable(img)
                target = Variable(target)
                if torch.cuda.is_available():
                    img = img.cuda()
                    target = target.cuda()
                output = cnn(img)

                # BUG FIX: the loss was never recomputed here, so the reported
                # "test_loss" was just the last *training* batch loss repeated.
                loss = criterion(output, target)
                acc = calculat_acc(output, target)
                acc_history.append(acc)
                loss_history.append(loss.item())
        print('test_loss: {:.4}|test_acc: {:.4}'.format(
            torch.mean(torch.Tensor(loss_history)),
            torch.mean(torch.Tensor(acc_history)),
        ))
        print('epoch: {}|time: {:.4f}'.format(epoch, time.time() - start_))
        torch.save(cnn.state_dict(), model_path)
Exemplo n.º 23
0
class Solver(object):
    def __init__(self, config, data_loader):
        """Keep references to the experiment configuration and batch iterator."""
        self.config, self.data_loader = config, data_loader

    def build(self, is_train):
        """Instantiate model and loss; attach an optimizer when training."""
        self.model = CNN(self.config)
        self.loss_fn = self.config.loss_fn()

        if not is_train:
            self.model.eval()
        else:
            self.model.train()
            self.optimizer = self.config.optimizer(
                self.model.parameters(), lr=self.config.lr)

    def train(self):
        """Train for config.epochs epochs, periodically logging and checkpointing."""
        for epoch in tqdm(range(self.config.epochs)):
            loss_history = []

            for batch_i, batch in enumerate(tqdm(self.data_loader)):
                # text: [max_seq_len, batch_size]
                # label: [batch_size]
                text, label = batch.text, batch.label

                # Transpose in place to [batch_size, max_seq_len].
                text.data.t_()

                # [batch_size, 2]
                logit = self.model(text)

                # Calculate loss
                average_batch_loss = self.loss_fn(logit, label)  # [1]
                # BUG FIX: .data[0] raises IndexError on 0-dim loss tensors in
                # modern PyTorch; .item() extracts the Python scalar safely.
                loss_history.append(average_batch_loss.item())

                # Flush out remaining gradient
                self.optimizer.zero_grad()

                # Backpropagation
                average_batch_loss.backward()

                # Gradient descent
                self.optimizer.step()

            # Log intermediate loss
            if (epoch + 1) % self.config.log_every_epoch == 0:
                epoch_loss = np.mean(loss_history)
                log_str = f'Epoch {epoch + 1} | loss: {epoch_loss:.2f}\n'
                print(log_str)

            # Save model parameters
            if (epoch + 1) % self.config.save_every_epoch == 0:
                ckpt_path = os.path.join(self.config.save_dir,
                                         f'epoch-{epoch+1}.pkl')
                print('Save parameters at ', ckpt_path)
                torch.save(self.model.state_dict(), ckpt_path)

    def eval(self, epoch=None):
        """Evaluate the model using parameters checkpointed at `epoch`.

        Falls back to the final epoch's checkpoint when `epoch` is not an int.
        """
        # Load model parameters
        if not isinstance(epoch, int):
            epoch = self.config.epochs
        ckpt_path = os.path.join(self.config.save_dir, f'epoch-{epoch}.pkl')
        print('Load parameters from ', ckpt_path)
        self.model.load_state_dict(torch.load(ckpt_path))

        loss_history = []
        for _, batch in tqdm(enumerate(self.data_loader)):
            # text: [max_seq_len, batch_size]
            # label: [batch_size]
            text, label = batch.text, batch.label

            # [batch_size, max_seq_len]
            text.data.t_()

            # [batch_size, 2]
            logit = self.model(text)

            # Calculate loss
            average_batch_loss = self.loss_fn(logit, label)  # [1]
            # BUG FIX: .item() instead of .data[0] (fails on 0-dim tensors).
            loss_history.append(average_batch_loss.item())

        epoch_loss = np.mean(loss_history)

        # BUG FIX: the original string was missing the f-prefix and printed
        # the literal text '{epoch_loss:.2f}'.
        print(f'Loss: {epoch_loss:.2f}')
    if device == "cuda":
        print(
            "Using GPU. Setting default tensor type to torch.cuda.FloatTensor")
        torch.set_default_tensor_type("torch.cuda.FloatTensor")
    else:
        print("Using CPU. Setting default tensor type to torch.FloatTensor")
        torch.set_default_tensor_type("torch.FloatTensor")
    """Converting model to specified hardware and format"""
    acoustic_cfg_json = json.load(
        open(args.acoustic_model.replace(".torch", ".json"), "r"))
    acoustic_cfg = AcousticConfig.from_json(acoustic_cfg_json)

    acoustic_model = CNN(acoustic_cfg)
    acoustic_model.float().to(device)
    try:
        acoustic_model.load_state_dict(torch.load(args.acoustic_model))
    except:
        print(
            "Failed to load model from {} without device mapping. Trying to load with mapping to {}"
            .format(args.acoustic_model, device))
        acoustic_model.load_state_dict(
            torch.load(args.acoustic_model, map_location=device))
    linguistic_cfg_json = json.load(
        open(args.linguistic_model.replace(".torch", ".json"), "r"))
    linguistic_cfg = LinguisticConfig.from_json(linguistic_cfg_json)

    linguistic_model = AttentionModel(linguistic_cfg)
    linguistic_model.float().to(device)

    try:
        linguistic_model.load_state_dict(torch.load(args.linguistic_model))
Exemplo n.º 25
0
    pickle.dump(params, open(init_params_file, 'wb'))


# initialize accumulators.
current_epoch = 1
batch_step_count = 1
time_used_global = 0.0
checkpoint = 1


# Resume from the most recent checkpoint, if any exist in model_dir.
model_list = os.listdir(model_dir)
if model_list:
    print('Loading lastest checkpoint...')
    state = load_model(model_dir, model_list)
    for component, state_key in ((encoder, 'encoder'),
                                 (decoder, 'decoder'),
                                 (optimizer, 'optimizer')):
        component.load_state_dict(state[state_key])
    current_epoch = state['epoch'] + 1
    time_used_global = state['time_used_global']
    batch_step_count = state['batch_step_count']

# Fine-tuning setup: very small learning rate, no weight decay.
for group in optimizer.param_groups:
    group['lr'] = 0.0000001
    group['weight_decay'] = 0.0

# Make sure encoder weights receive gradients.
for param in encoder.parameters():
    param.requires_grad_(requires_grad=True)

BATCH_SIZE = 16
Exemplo n.º 26
0
app = Flask(__name__,
            static_folder="./backend/static",
            template_folder="./backend/templates")
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
bootstrap = Bootstrap(app)

# Allow cross-origin requests from any origin.
cors = CORS(app, resources={r"/*": {"origins": "*"}})

# Label <-> index mappings for the classifier outputs.
class2index = json.load(open("class_index.json"))
index2class = {idx: name for name, idx in class2index.items()}

# Declared here so the model is loaded exactly once, at app start-up.
device = torch.device('cpu')
model = CNN()
model.load_state_dict(torch.load('pickles/cnn.pkl'))
model.to(device=device)

trans = Compose([ToTensor(), PaddingSame2d(seq_len=224, value=0)])


def allowed_file(filename):
    """Return True when `filename` has an extension in ALLOWED_EXTENSIONS."""
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1].lower()
    return extension in ALLOWED_EXTENSIONS


def transform_audio(audio_bytes):
    """Extract MFCC features from raw audio bytes and apply the module's transform pipeline."""
    return trans(get_mfcc(audio_bytes))
Exemplo n.º 27
0
                                            transforms.ToTensor(),
                                            transforms.Normalize((0.5,), (0.5,))
                                        ]))


    target_test_loader = torch.utils.data.DataLoader(
        dataset=dataset_target_test,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8
    )


    checkpoint = torch.load(args.base_classifier)
    base_classifier = CNN(in_channels=3, target=True).to(device)
    base_classifier.load_state_dict(checkpoint['model'])

    # create the smooothed classifier g
    smoothed_classifier = Smooth(base_classifier, num_classes= 10, sigma=args.sigma)

    # prepare output file
    # f = open(args.outfile, 'w')
    # print("idx\tlabel\tpredict\tradius\tcorrect\ttime", file=f, flush=True)

    # iterate through the dataset
    n_total = 0
    n_correct = 0
    thresh_list = [0,0.5,1.0,1.5,2.0,2.5,3.0]
    correct_list = [0]*7
    for i,data in enumerate(target_test_loader):
        if i % 100 == 0:
Exemplo n.º 28
0
from models import CNN
from datasets import img_loader
from torchvision.transforms import Compose, ToTensor
from train import num_class
from pathlib import Path

model_path = './checkpoints/model.pth'

# Alphabet: digits 0-9 followed by lowercase a-z (36 classes per character).
source = [str(i) for i in range(0, 10)]
source += [chr(i) for i in range(97, 97 + 26)]
alphabet = ''.join(source)

transforms = Compose([ToTensor()])
img = img_loader('captcha_test/captcha_crop.png')
img = transforms(img)
cnn = CNN()
use_cuda = torch.cuda.is_available()
if use_cuda:
    cnn = cnn.cuda()
cnn.load_state_dict(torch.load(model_path))
print("---------")
# BUG FIX: only move the input to the GPU when CUDA is available; the
# original called .cuda() unconditionally and crashed on CPU-only hosts.
img = img.view(1, 3, 36, 120)
if use_cuda:
    img = img.cuda()
output = cnn(img)

# Reshape to (num_chars, 36) and decode each character position independently.
output = output.view(-1, 36)
output = nn.functional.softmax(output, dim=1)
output = torch.argmax(output, dim=1)
output = output.view(-1, num_class)[0]

pred = ''.join([alphabet[i] for i in output.cpu().numpy()])
print(pred)
Exemplo n.º 29
0
# -*- coding: utf-8 -*-
"""
Created on Tue Feb 25 12:01:40 2020

@author: Ernest Namdar

Ref:
    https://github.com/onnx/tutorials/blob/master/tutorials/PytorchTensorflowMnist.ipynb
"""

from torch.autograd import Variable
import torch
from models import CNN


# Load the trained model
PyTorch_model = CNN()
PyTorch_model.load_state_dict(torch.load("./CNN_State.pt"))
# FIX: switch to inference mode before export so dropout/batchnorm layers
# are traced with their evaluation behavior.
PyTorch_model.eval()

# Export the trained model to ONNX
# NOTE: Variable is a deprecated no-op wrapper since PyTorch 0.4; a plain
# tensor is the correct dummy input.
dummy_input = torch.randn(1, 1, 28, 28)  # one single-channel 28x28 picture will be the input to the model
torch.onnx.export(PyTorch_model, dummy_input, "./CNN.onnx")
Exemplo n.º 30
0
class Solver(object):
    def __init__(self, config, data_loader):
        self.config = config
        self.data_loader = data_loader

    def build(self, is_train):
        """Create model and loss function; attach an optimizer in training mode."""
        self.model = CNN(self.config)
        self.loss_fn = self.config.loss_fn()

        if not is_train:
            self.model.eval()
            return
        self.model.train()
        self.optimizer = self.config.optimizer(self.model.parameters(),
                                               lr=self.config.lr)

    def save(self, ckpt_path):
        """Save model parameters"""
        print('Save parameters at ', ckpt_path)
        torch.save(self.model.state_dict(), ckpt_path)

    def load(self, ckpt_path=None, epoch=None):
        """Load model parameters"""
        if not (ckpt_path or epoch):
            epoch = self.config.epochs
        if epoch:
            ckpt_path = os.path.join(self.config.save_dir, f'epoch-{epoch}.pkl')
        print('Load parameters from ', ckpt_path)
        self.model.load_state_dict(torch.load(ckpt_path))

    def train(self):
        """Train model with training data, logging and checkpointing periodically."""
        for epoch in tqdm(range(self.config.epochs)):
            loss_history = []

            for batch_i, batch in enumerate(tqdm(self.data_loader)):
                # text: [max_seq_len, batch_size]
                # label: [batch_size]
                text, label = batch.text, batch.label

                # Transpose in place to [batch_size, max_seq_len].
                text.data.t_()

                # [batch_size, 2]
                logit = self.model(text)

                # Calculate loss
                average_batch_loss = self.loss_fn(logit, label)  # [1]
                # BUG FIX: .data[0] raises IndexError on 0-dim loss tensors in
                # modern PyTorch; .item() extracts the Python scalar safely.
                loss_history.append(average_batch_loss.item())

                # Flush out remaining gradient
                self.optimizer.zero_grad()

                # Backpropagation
                average_batch_loss.backward()

                # Gradient descent
                self.optimizer.step()

            # Log intermediate loss
            if (epoch + 1) % self.config.log_every_epoch == 0:
                epoch_loss = np.mean(loss_history)
                log_str = f'Epoch {epoch + 1} | loss: {epoch_loss:.2f}\n'
                print(log_str)

            # Save model parameters
            if (epoch + 1) % self.config.save_every_epoch == 0:
                ckpt_path = os.path.join(self.config.save_dir, f'epoch-{epoch+1}.pkl')
                self.save(ckpt_path)

    def eval(self):
        """Evaluate model from text data.

        Returns:
            (mean_loss, accuracy) computed over the whole data loader.
        """
        n_total_data = 0
        n_correct = 0
        loss_history = []
        # BUG FIX: removed a leftover `import ipdb; ipdb.set_trace()` debugger
        # breakpoint that halted every evaluation run.
        for _, batch in enumerate(tqdm(self.data_loader)):
            # text: [max_seq_len, batch_size]
            # label: [batch_size]
            text, label = batch.text, batch.label

            # Transpose in place to [batch_size, max_seq_len].
            text.data.t_()

            # [batch_size, 2]
            logit = self.model(text)

            # Calculate loss
            average_batch_loss = self.loss_fn(logit, label)  # [1]
            # BUG FIX: .item() instead of .data[0] (fails on 0-dim tensors).
            loss_history.append(average_batch_loss.item())

            # Calculate accuracy
            n_total_data += len(label)

            # [batch_size]
            _, prediction = logit.max(1)

            n_correct += (prediction == label).sum().item()

        epoch_loss = np.mean(loss_history)

        accuracy = n_correct / float(n_total_data)

        print(f'Loss: {epoch_loss:.2f}')

        print(f'Accuracy: {accuracy}')

        # Return the metrics so callers can log them (matches the variant of
        # this method that returns (epoch_loss, accuracy)).
        return epoch_loss, accuracy

    def inference(self, text):

        text = Variable(torch.LongTensor([text]))

        # [batch_size, 2]
        logit = self.model(text)

        _, prediction = torch.max(logit)

        return prediction