Ejemplo n.º 1
0
 def __init__(self, job_dir, device, model, saver, saved_model_path, task):
     """Store the collaborators and restore saved weights into ``model``.

     The state dict found at ``saved_model_path`` is loaded with its
     tensors mapped onto ``device`` and applied to ``model`` in place.
     The prediction function is then looked up by ``task`` in
     ``task_predict_fn_dict``; an unsupported ``task`` fails at that
     lookup (``KeyError`` if the table is a plain dict -- TODO confirm).
     """
     self.job_dir = job_dir
     self.device = device
     self.model = model
     self.logger = get_logger(job_dir)
     self.saver = saver
     self.saved_model_path = saved_model_path

     # NOTE(review): torch.load relies on pickle unless weights-only
     # loading is in effect for the installed PyTorch version -- only
     # point this at checkpoints from a trusted source.
     state_dict = torch.load(
         self.saved_model_path,
         map_location=self.device,
     )
     self.model.load_state_dict(state_dict)

     # Resolved last, so a bad task surfaces only after the weights load.
     self.predict_fn = task_predict_fn_dict[task]
Ejemplo n.º 2
0
 def __init__(
         self, job_dir, device, model, criterion, lr_scheduler, hooks, task
 ):
     """Store the training collaborators and pick the prediction function.

     ``task`` must be a key of ``task_predict_fn_dict``; this is checked
     with an ``assert`` (so the check is skipped under ``python -O``) and
     the matching entry becomes ``self.predict_fn``. A logger is obtained
     from ``get_logger(job_dir)`` and ``stop_signal`` starts out False.
     """
     # Reject an unsupported task before anything else is set up.
     assert task in task_predict_fn_dict
     self.predict_fn = task_predict_fn_dict[task]

     # Collaborators are kept exactly as passed in.
     self.model = model
     self.device = device
     self.criterion = criterion
     self.lr_scheduler = lr_scheduler
     self.hooks = hooks

     self.job_dir = job_dir
     self.logger = get_logger(job_dir)

     # Presumably flipped elsewhere to request that the run stop --
     # confirm against the code that reads this flag.
     self.stop_signal = False
Ejemplo n.º 3
0
from torchvision.datasets import FashionMNIST
import torchvision.transforms as transforms
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from tqdm import tqdm, trange
from thop import profile

from trainer.models.cnn import CNN
from trainer.models.cnn_bn import CNNwithBN
from trainer.models.mobilenetv2 import MobileNetV2
from trainer.models.mobilenetv3 import MobileNetV3
from trainer.logger import get_logger

logger = get_logger()


class Trainer:
    def __init__(self, args: argparse.Namespace):
        """Set up the trainer from parsed command-line arguments.

        Keeps ``args``, stores a shallow copy of its attributes as
        ``self.config`` and logs it, selects CUDA when available
        (otherwise CPU), builds the model, dataset, optimizer and loss
        function through the ``_build_*`` helpers, and opens a
        ``SummaryWriter`` on ``args.output_dir``.
        """
        self.args = args
        self.config = vars(self.args).copy()
        logger.info(self.config)

        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)

        # Build order is kept as written: the helpers are not visible here
        # and later ones may read attributes set by earlier ones.
        self.model = self._build_model()
        self.dataset = self._build_dataset()
        self.optimizer = self._build_optimizer()
        self.loss_fct = self._build_loss_fct()

        self.writer = SummaryWriter(self.args.output_dir)
Ejemplo n.º 4
0
import shutil
import sys
from argparse import ArgumentParser
from collections import Counter
from pathlib import Path
from zipfile import ZipFile

import numpy as np
import pandas as pd
import requests

from trainer.logger import get_logger

logger = get_logger(__name__)


def download_data(url="http://mattmahoney.net/dc/text8.zip", dest_dir="data"):
    # prepare destination
    dest = Path(dest_dir) / Path(url).name
    dest.parent.mkdir(parents=True, exist_ok=True)

    # download zip
    if not dest.exists():
        logger.info("downloading file: %s.", url)
        r = requests.get(url, stream=True)
        with dest.open("wb") as f:
            shutil.copyfileobj(r.raw, f)
        logger.info("file downloaded: %s.", dest)

    # extract zip
    if not Path(dest_dir, "text8").exists():
Ejemplo n.º 5
0
import json
import os
import sys
from argparse import ArgumentParser

import tensorflow as tf

from trainer.logger import get_logger
from trainer.train_estimator_v1 import get_predict_input_fn, model_fn
from trainer.train_utils import get_estimator

logger = get_logger(__name__)


def format_predictions(predictions):
    embeddings = {}
    for instance in predictions:
        # need to convert byte to string
        # and remove numpy types
        item_id = instance["row_id"].decode()
        row_embed = instance["row_embed"].tolist()
        col_embed = instance["col_embed"].tolist()

        # add entry only if valid predictions
        if item_id != "<UNK>":
            instance_embedding = {
                "item_id": item_id,
                "row_embed": row_embed,
                "col_embed": col_embed
            }
            embeddings[item_id] = instance_embedding