Code Example #1
    def build(self, cuda=True):
        if self.model is None:
            self.model = getattr(models, self.config.model)(self.config)

            if self.config.mode == 'train' and self.config.checkpoint is None:
                print('Parameter initialization')
                for name, param in self.model.named_parameters():
                    if 'weight_hh' in name:
                        print('\t' + name)
                        nn.init.orthogonal_(param)

                    if 'bias_hh' in name:
                        print('\t' + name)
                        dim = int(param.size(0) / 3)
                        param.data[dim:2 * dim].fill_(2.0)

        if torch.cuda.is_available() and cuda:
            self.model.cuda()

        if self.config.checkpoint:
            self.load_model(self.config.checkpoint)

        if self.is_train:
            self.writer = TensorboardWriter(self.config.logdir)
            self.optimizer = self.config.optimizer(
                filter(lambda p: p.requires_grad, self.model.parameters()),
                lr=self.config.learning_rate)
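
Every example in this collection depends on a project-local TensorboardWriter helper rather than using torch.utils.tensorboard directly. A minimal sketch of what such a wrapper might look like, assuming it simply forwards to SummaryWriter (the enabled flag mirrors the two-argument calls like TensorboardWriter(log_dir, tensorboard) seen in later examples; the real class in each project may differ):

from torch.utils.tensorboard import SummaryWriter

class TensorboardWriter:
    """Thin wrapper around SummaryWriter (hypothetical sketch)."""

    def __init__(self, log_dir, enabled=True):
        # Only create the underlying writer when logging is enabled.
        self.writer = SummaryWriter(log_dir=str(log_dir)) if enabled else None

    def add_scalar(self, tag, value, step=None):
        # Forward scalar logging to the underlying SummaryWriter.
        if self.writer is not None:
            self.writer.add_scalar(tag, value, step)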
Code Example #2
    def build(self):

        # Build Modules
        self.linear_compress = nn.Linear(self.config.input_size,
                                         self.config.hidden_size).cuda()
        self.summarizer = Summarizer(input_size=self.config.hidden_size,
                                     hidden_size=self.config.hidden_size,
                                     num_layers=self.config.num_layers).cuda()
        self.discriminator = Discriminator(
            input_size=self.config.hidden_size,
            hidden_size=self.config.hidden_size,
            num_layers=self.config.num_layers).cuda()
        self.model = nn.ModuleList(
            [self.linear_compress, self.summarizer, self.discriminator])

        if self.config.mode == 'train':
            # Build Optimizers
            self.s_e_optimizer = optim.Adam(
                list(self.summarizer.s_lstm.parameters()) +
                list(self.summarizer.vae.e_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.d_optimizer = optim.Adam(
                list(self.summarizer.vae.d_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.c_optimizer = optim.Adam(
                list(self.discriminator.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.discriminator_lr)

            self.writer = TensorboardWriter(str(self.config.log_dir))
Code Example #3
    def build(self):

        # Build Modules
        # self.summarizer = simple_encoder_LSTM(
        #     input_size=self.config.input_size,
        #     hidden_size=self.config.hidden_size,
        #     num_layers=self.config.num_layers).cuda()
        self.summarizer = attentive_encoder_LSTM(
            input_size=self.config.input_size,
            hidden_size=self.config.hidden_size,
            num_layers=self.config.num_layers).cuda()
        # self.summarizer = attentive_encoder_decoder_LSTM(
        #     input_size=self.config.input_size,
        #     hidden_size=self.config.hidden_size,
        #     num_layers=self.config.num_layers).cuda()

        if self.config.mode == 'train':
            # Build Optimizers
            self.optimizer = optim.Adam(self.summarizer.parameters(),
                                        lr=self.config.lr)

            self.summarizer.train()

            # Tensorboard
            self.writer = TensorboardWriter(self.config.log_dir)
Code Example #4
    def __init__(self, model, criterion, optimizer, scheduler, metric_ftns,
                 device, num_epoch, grad_clipping, grad_accumulation_steps,
                 early_stopping, validation_frequency, tensorboard,
                 checkpoint_dir, resume_path):
        self.device, device_ids = self._prepare_device(device)
        # self.model = model.to(self.device)

        self.start_epoch = 1
        if resume_path is not None:
            self._resume_checkpoint(resume_path)
        if len(device_ids) > 1:
            # self.model = torch.nn.DataParallel(model, device_ids=device_ids)
            self.model = torch.nn.DataParallel(model)
            # cudnn.benchmark = True
        # use_cuda was undefined in the original snippet; test availability directly
        if torch.cuda.is_available():
            self.model = model.cuda()
        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer
        self.num_epoch = num_epoch
        self.scheduler = scheduler
        self.grad_clipping = grad_clipping
        self.grad_accumulation_steps = grad_accumulation_steps
        self.early_stopping = early_stopping
        self.validation_frequency = validation_frequency
        self.checkpoint_dir = checkpoint_dir
        self.best_epoch = 1
        self.best_score = 0
        self.writer = TensorboardWriter(
            os.path.join(checkpoint_dir, 'tensorboard'), tensorboard)
        self.train_metrics = MetricTracker('loss', writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
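
MetricTracker is likewise a project-local helper (the pattern is common in pytorch-template-derived repos). A minimal sketch, assuming it keeps running averages of named metrics; every name and method here is illustrative:

class MetricTracker:
    """Running averages for named metrics (hypothetical sketch)."""

    def __init__(self, *keys, writer=None):
        self.writer = writer
        self._totals = {key: 0.0 for key in keys}
        self._counts = {key: 0 for key in keys}

    def update(self, key, value, n=1):
        # Accumulate a weighted value and optionally log it.
        self._totals[key] += value * n
        self._counts[key] += n
        if self.writer is not None:
            self.writer.add_scalar(key, value)

    def avg(self, key):
        return self._totals[key] / max(self._counts[key], 1)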
Code Example #5
File: solver.py Project: herok97/2021-1_Lecture
    def build(self):
        # Code I added
        torch.backends.cudnn.enabled = False

        # Code I added / GPU info
        USE_CUDA = torch.cuda.is_available()
        print(USE_CUDA)
        device = torch.device('cuda:0' if USE_CUDA else 'cpu')
        print('Training device:', device)
        print('cuda index:', torch.cuda.current_device())
        print('gpu count:', torch.cuda.device_count())
        print('graphic name:', torch.cuda.get_device_name())
        # setting device on GPU if available, else CPU
        print('Using device:', device)

        # Build Modules
        self.linear_compress = nn.Linear(self.config.input_size,
                                         self.config.hidden_size).cuda()
        self.summarizer = Summarizer(input_size=self.config.hidden_size,
                                     hidden_size=self.config.hidden_size,
                                     num_layers=self.config.num_layers).cuda()
        self.discriminator = Discriminator(
            input_size=self.config.hidden_size,
            hidden_size=self.config.hidden_size,
            num_layers=self.config.num_layers).cuda()
        self.model = nn.ModuleList(
            [self.linear_compress, self.summarizer, self.discriminator])

        if self.config.mode == 'train':
            # Build Optimizers
            self.s_e_optimizer = optim.Adam(
                list(self.summarizer.s_lstm.parameters()) +
                list(self.summarizer.vae.e_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.d_optimizer = optim.Adam(
                list(self.summarizer.vae.d_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.c_optimizer = optim.Adam(
                list(self.discriminator.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.discriminator_lr)
            print(self.model)
            self.model.train()
            # self.model.apply(apply_weight_norm)

            # Overview Parameters
            # Train only the VAE
            # print('Model Parameters')
            # for name, param in self.model.named_parameters():
            #     print('\t' + name + '\t', list(param.size()))
            #     if 'vae' not in name:
            #         param.requires_grad = False
            #     print('\t train: ' + '\t', param.requires_grad)

            # Tensorboard (commenting out done by me)
            self.writer = TensorboardWriter(self.config.log_dir)
Code Example #6
File: solver.py Project: NoSyu/CDMM-B
    def build(self, cuda=True):
        if self.model is None:
            self.model = getattr(models, self.config.model)(self.config)

            if self.config.mode == 'train' and self.config.checkpoint is None:
                print('Parameter initialization')
                for name, param in self.model.named_parameters():
                    if 'weight_hh' in name:
                        print('\t' + name)
                        nn.init.orthogonal_(param)

                    if 'bias_hh' in name:
                        print('\t' + name)
                        dim = int(param.size(0) / 3)
                        param.data[dim:2 * dim].fill_(2.0)

        if torch.cuda.is_available() and cuda:
            self.model.cuda()

        print('Model Parameters')
        for name, param in self.model.named_parameters():
            print('\t' + name + '\t', list(param.size()))

        if self.config.checkpoint:
            self.load_model(self.config.checkpoint)

        if self.is_train:
            self.writer = TensorboardWriter(self.config.logdir)
            if self.config.optimizer is None:
                # AdamW
                no_decay = ['bias', 'LayerNorm.weight']
                optimizer_grouped_parameters = [{
                    'params': [
                        p for n, p in self.model.named_parameters()
                        if not any(nd in n for nd in no_decay)
                    ],
                    'weight_decay':
                    0.01
                }, {
                    'params': [
                        p for n, p in self.model.named_parameters()
                        if any(nd in n for nd in no_decay)
                    ],
                    'weight_decay':
                    0.0
                }]
                self.optimizer = AdamW(optimizer_grouped_parameters,
                                       lr=self.config.learning_rate)
            else:
                self.optimizer = self.config.optimizer(
                    filter(lambda p: p.requires_grad, self.model.parameters()),
                    lr=self.config.learning_rate)
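
The AdamW branch above applies the common trick of exempting biases and LayerNorm weights from weight decay. The grouping can be factored into a reusable helper; a sketch under the same assumptions (the no_decay names and the 0.01 decay value come from the snippet, the helper name is ours):

def group_parameters(model, weight_decay=0.01,
                     no_decay=('bias', 'LayerNorm.weight')):
    # Split parameters into decayed and non-decayed groups for AdamW.
    decayed = [p for n, p in model.named_parameters()
               if not any(nd in n for nd in no_decay)]
    undecayed = [p for n, p in model.named_parameters()
                 if any(nd in n for nd in no_decay)]
    return [{'params': decayed, 'weight_decay': weight_decay},
            {'params': undecayed, 'weight_decay': 0.0}]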
Code Example #7
    def __init__(
        self,
        caption_model: str,
        epochs: int,
        device: torch.device,
        word_map: Dict[str, int],
        rev_word_map: Dict[int, str],
        start_epoch: int,
        epochs_since_improvement: int,
        best_bleu4: float,
        train_loader: DataLoader,
        val_loader: DataLoader,
        encoder: nn.Module,
        decoder: nn.Module,
        encoder_optimizer: optim.Optimizer,
        decoder_optimizer: optim.Optimizer,
        loss_function: nn.Module,
        grad_clip: float,
        tau: float,
        fine_tune_encoder: bool,
        tensorboard: bool = False,
        log_dir: Optional[str] = None
    ) -> None:
        self.device = device  # GPU / CPU

        self.caption_model = caption_model
        self.epochs = epochs
        self.word_map = word_map
        self.rev_word_map = rev_word_map

        self.start_epoch = start_epoch
        self.epochs_since_improvement = epochs_since_improvement
        self.best_bleu4 = best_bleu4

        self.train_loader = train_loader
        self.val_loader = val_loader
        self.encoder = encoder
        self.decoder = decoder
        self.encoder_optimizer = encoder_optimizer
        self.decoder_optimizer = decoder_optimizer
        self.loss_function = loss_function

        self.tau = tau
        self.grad_clip = grad_clip
        self.fine_tune_encoder = fine_tune_encoder

        self.print_freq = 100  # print training/validation stats every __ batches
        # setup visualization writer instance
        self.writer = TensorboardWriter(log_dir, tensorboard)
        self.len_epoch = len(self.train_loader)
Code Example #8
    def build(self, cuda=True):
        if self.model is None:
            self.model = getattr(models, self.config.model)(self.config)

        if torch.cuda.is_available() and cuda:
            self.model.cuda()

        if self.config.checkpoint:
            self.load_model(self.config.checkpoint)

        if self.is_train:
            self.writer = TensorboardWriter(self.config.logdir)
            self.optimizer = self.config.optimizer(
                filter(lambda p: p.requires_grad, self.model.parameters()),
                lr=self.config.learning_rate)
Code Example #9
    def __init__(self, model, criterion, metric_ftns, optimizer, config):
        self.config = config
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])

        # setup GPU device if available, move model into configured device
        self.device, self.device_ids = self._prepare_device(
            config['num_gpu'], config['main_device_id'], config['device_id'])
        self.model = model.cuda(self.device)
        if len(self.device_ids) > 1:
            self.model = torch.nn.DataParallel(model,
                                               device_ids=self.device_ids)

        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')
        self.add_graph = cfg_trainer.get('add_graph', False)

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        self.start_epoch = 1
        self.best_epoch = self.start_epoch
        self.checkpoint_dir = config.save_dir
        self.logger_dir = config.log_dir
        self.ner_type = config.config['experiment_name'].split('_')[-1]

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])

        if config.resume is not None:
            self._resume_checkpoint(config.resume)
Code Example #10
    def build(self, cuda=True):

        if self.model is None:
            self.model = getattr(models, self.config.model)(self.config)

            # orthogonal initialiation for hidden weights
            # input gate bias for GRUs
            if self.config.mode == 'train' and self.config.checkpoint is None:
                print('Parameter initialization')
                for name, param in self.model.named_parameters():
                    if 'weight_hh' in name:
                        print('\t' + name)
                        nn.init.orthogonal_(param)

                    # bias_hh is concatenation of reset, input, new gates
                    # only set the input gate bias to 2.0
                    if 'bias_hh' in name:
                        print('\t' + name)
                        dim = int(param.size(0) / 3)
                        param.data[dim:2 * dim].fill_(2.0)

        # if torch.cuda.is_available() and cuda:
        #    self.model.cuda()

        if torch.cuda.is_available() and cuda:
            self.model = self.model.cuda()
        """
        if torch.cuda.device_count() > 1:
            device_ids = [0, 1, 2, 3]
            self.model = nn.DataParallel(self.model, device_ids=device_ids)
        """

        # Overview Parameters
        print('Model Parameters')
        for name, param in self.model.named_parameters():
            print('\t' + name + '\t', list(param.size()))

        if self.config.checkpoint:
            self.load_model(self.config.checkpoint)

        if self.is_train:
            self.writer = TensorboardWriter(self.config.logdir)
            self.optimizer = self.config.optimizer(
                filter(lambda p: p.requires_grad, self.model.parameters()),
                lr=self.config.learning_rate)
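
The dim = int(param.size(0) / 3) slice works because PyTorch stores a GRU layer's bias_hh as the three gate biases concatenated in the order reset, update (the "input" gate in the comments above), new. A standalone sketch of the same initialization on a bare nn.GRU (the sizes are illustrative):

import torch
from torch import nn

gru = nn.GRU(input_size=16, hidden_size=32)
for name, param in gru.named_parameters():
    if 'weight_hh' in name:
        nn.init.orthogonal_(param)          # orthogonal recurrent weights
    if 'bias_hh' in name:
        dim = param.size(0) // 3            # one chunk per gate: r, z, n
        param.data[dim:2 * dim].fill_(2.0)  # bias the update gate open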
Code Example #11
    def build(self):

        # Build Modules
        self.linear_compress = nn.Linear(self.config.input_size,
                                         self.config.hidden_size).cuda()
        self.summarizer = Summarizer(input_size=self.config.hidden_size,
                                     hidden_size=self.config.hidden_size,
                                     num_layers=self.config.num_layers).cuda()
        self.discriminator = Discriminator(
            input_size=self.config.hidden_size,
            hidden_size=self.config.hidden_size,
            num_layers=self.config.num_layers).cuda()
        self.model = nn.ModuleList(
            [self.linear_compress, self.summarizer, self.discriminator])

        if self.config.mode == 'train':
            # Build Optimizers
            self.s_e_optimizer = optim.Adam(
                list(self.summarizer.s_lstm.parameters()) +
                list(self.summarizer.vae.e_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.d_optimizer = optim.Adam(
                list(self.summarizer.vae.d_lstm.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.lr)
            self.c_optimizer = optim.Adam(
                list(self.discriminator.parameters()) +
                list(self.linear_compress.parameters()),
                lr=self.config.discriminator_lr)

            self.model.train()
            # self.model.apply(apply_weight_norm)

            # Overview Parameters
            # print('Model Parameters')
            # for name, param in self.model.named_parameters():
            #     print('\t' + name + '\t', list(param.size()))

            # Tensorboard
            self.writer = TensorboardWriter(self.config.log_dir)
Code Example #12
    def __init__(self,
                 num_epochs: int,
                 start_epoch: int,
                 train_loader: DataLoader,
                 model: nn.Module,
                 model_name: str,
                 loss_function: nn.Module,
                 optimizer,
                 lr_decay: float,
                 dataset_name: str,
                 word_map: Dict[str, int],
                 grad_clip: Optional[float] = None,
                 print_freq: int = 100,
                 checkpoint_path: Optional[str] = None,
                 checkpoint_basename: str = 'checkpoint',
                 tensorboard: bool = False,
                 log_dir: Optional[str] = None) -> None:
        self.num_epochs = num_epochs
        self.start_epoch = start_epoch
        self.train_loader = train_loader

        self.model = model
        self.model_name = model_name
        self.loss_function = loss_function
        self.optimizer = optimizer
        self.lr_decay = lr_decay

        self.dataset_name = dataset_name
        self.word_map = word_map
        self.print_freq = print_freq
        self.grad_clip = grad_clip

        self.checkpoint_path = checkpoint_path
        self.checkpoint_basename = checkpoint_basename

        # setup visualization writer instance
        self.writer = TensorboardWriter(log_dir, tensorboard)
        self.len_epoch = len(self.train_loader)
Code Example #13
File: train.py Project: NoAchache/TextBoxGan
    def __init__(self):

        self.batch_size = cfg.batch_size
        self.strategy = cfg.strategy
        self.max_steps = cfg.max_steps
        self.summary_steps_frequency = cfg.summary_steps_frequency
        self.image_summary_step_frequency = cfg.image_summary_step_frequency
        self.save_step_frequency = cfg.save_step_frequency
        self.log_dir = cfg.log_dir

        self.validation_step_frequency = cfg.validation_step_frequency
        self.tensorboard_writer = TensorboardWriter(self.log_dir)
        # set optimizer params
        self.g_opt = self.update_optimizer_params(cfg.g_opt)
        self.d_opt = self.update_optimizer_params(cfg.d_opt)
        self.pl_mean = tf.Variable(
            initial_value=0.0,
            name="pl_mean",
            trainable=False,
            synchronization=tf.VariableSynchronization.ON_READ,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        )
        self.training_data_loader = TrainingDataLoader()
        self.validation_data_loader = ValidationDataLoader("validation_corpus.txt")
        self.model_loader = ModelLoader()
        # create model: model and optimizer must be created under `strategy.scope`
        (
            self.discriminator,
            self.generator,
            self.g_clone,
        ) = self.model_loader.initiate_models()

        # set optimizers
        self.d_optimizer = tf.keras.optimizers.Adam(
            self.d_opt["learning_rate"],
            beta_1=self.d_opt["beta1"],
            beta_2=self.d_opt["beta2"],
            epsilon=self.d_opt["epsilon"],
        )
        self.g_optimizer = tf.keras.optimizers.Adam(
            self.g_opt["learning_rate"],
            beta_1=self.g_opt["beta1"],
            beta_2=self.g_opt["beta2"],
            epsilon=self.g_opt["epsilon"],
        )
        self.ocr_optimizer = tf.keras.optimizers.Adam(
            self.g_opt["learning_rate"],
            beta_1=self.g_opt["beta1"],
            beta_2=self.g_opt["beta2"],
            epsilon=self.g_opt["epsilon"],
        )
        self.ocr_loss_weight = cfg.ocr_loss_weight

        self.aster_ocr = AsterInferer()

        self.training_step = TrainingStep(
            self.generator,
            self.discriminator,
            self.aster_ocr,
            self.g_optimizer,
            self.ocr_optimizer,
            self.d_optimizer,
            self.g_opt["reg_interval"],
            self.d_opt["reg_interval"],
            self.pl_mean,
        )

        self.validation_step = ValidationStep(self.g_clone, self.aster_ocr)

        self.manager = self.model_loader.load_checkpoint(
            ckpt_kwargs={
                "d_optimizer": self.d_optimizer,
                "g_optimizer": self.g_optimizer,
                "ocr_optimizer": self.ocr_optimizer,
                "discriminator": self.discriminator,
                "generator": self.generator,
                "g_clone": self.g_clone,
                "pl_mean": self.pl_mean,
            },
            model_description="Full model",
            expect_partial=False,
            ckpt_dir=cfg.ckpt_dir,
            max_to_keep=cfg.num_ckpts_to_keep,
        )
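
The "model and optimizer must be created under strategy.scope" comment refers to tf.distribute's requirement that variables be created inside the active distribution strategy's scope. A minimal illustration (MirroredStrategy is an assumption; cfg.strategy in the project may be configured differently):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    # Variables created here are mirrored across all replicas.
    model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)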
Code Example #14
    def build(self):
        # Build Modules
        # self.device = torch.device('cuda:0,1')
        self.embedding = nn.Embedding(self.config.vocab_size,
                                      self.config.wemb_size,
                                      padding_idx=0)

        # Load pretrained word vectors; from_pretrained is a classmethod that
        # returns a new module, so its result must be assigned.
        weights_matrix = torch.FloatTensor(
            pickle.load(open(p.word_vec_pkl, 'rb')))
        self.embedding = nn.Embedding.from_pretrained(weights_matrix,
                                                      freeze=False,
                                                      padding_idx=0)

        # One projection per layer; a list comprehension creates distinct
        # modules (multiplying a one-element list would register the same
        # Linear num_layers times, sharing its weights).
        self.w_hr_fw = nn.ModuleList([
            nn.Linear(
                self.config.hidden_size, self.config.kwd_size, bias=False)
            for _ in range(self.config.num_layers)
        ])
        self.w_hr_bw = nn.ModuleList([
            nn.Linear(
                self.config.hidden_size, self.config.kwd_size, bias=False)
            for _ in range(self.config.num_layers)
        ])

        self.w_wr = nn.Linear(self.config.wemb_size,
                              self.config.kwd_size,
                              bias=False)
        self.w_ho_fw = nn.Sequential(
            nn.Linear(self.config.hidden_size * self.config.num_layers,
                      self.config.vocab_size),
            #             nn.LogSoftmax(dim=-1)
        )
        self.w_ho_bw = nn.Linear(
            self.config.hidden_size * self.config.num_layers,
            self.config.vocab_size)
        self.sc_rnn_fw = SCLSTM_MultiCell(self.config.num_layers,
                                          self.config.wemb_size,
                                          self.config.hidden_size,
                                          self.config.kwd_size,
                                          dropout=self.config.drop_rate)

        self.sc_rnn_bw = SCLSTM_MultiCell(self.config.num_layers,
                                          self.config.wemb_size,
                                          self.config.hidden_size,
                                          self.config.kwd_size,
                                          dropout=self.config.drop_rate)

        self.model = nn.ModuleList([
            self.w_hr_fw, self.w_hr_bw, self.w_wr, self.w_ho_fw, self.w_ho_bw,
            self.sc_rnn_fw, self.sc_rnn_bw
        ])

        self.criterion = nn.CrossEntropyLoss(reduction='none')

        # Plain zero tensors suffice here; Variable is a no-op in modern
        # PyTorch and requires_grad already defaults to False.
        self.hc_list_init = (torch.zeros(self.config.num_layers,
                                         self.config.batch_size,
                                         self.config.hidden_size),
                             torch.zeros(self.config.num_layers,
                                         self.config.batch_size,
                                         self.config.hidden_size))

        #--- Init dirs for output ---
        self.current_time = datetime.now().strftime('%b%d_%H-%M-%S')
        if self.config.mode == 'train':
            # Overview Parameters
            print('Init Model Parameters')
            for name, param in self.model.named_parameters():
                print('\t' + name + '\t', list(param.size()))
                if param.data.ndimension() >= 2:
                    nn.init.xavier_uniform_(param.data)
                else:
                    nn.init.zeros_(param.data)

            # Tensorboard
            self.writer = TensorboardWriter(p.tb_dir + self.current_time)
            # Add emb-layer
            self.model.train()
            # create dir
            #             self.res_dir = p.result_path.format(p.dataname, self.current_time) # result dir
            self.cp_dir = p.check_point.format(
                p.dataname, self.current_time)  # checkpoint dir
            #             os.makedirs(self.res_dir)
            os.makedirs(self.cp_dir)

        #--- Setup output file ---
        self.out_file = open(
            p.out_result_dir.format(p.dataname, self.current_time), 'w')

        self.model.append(self.embedding)
        #         self.model.to(self.device)
        # Build Optimizers
        self.optimizer = optim.Adam(list(self.model.parameters()),
                                    lr=self.config.lr)
        print(self.model)
Code Example #15
import os
import yaml
from loguru import logger
from time import gmtime, strftime
from utils import TensorboardWriter

if not os.path.isdir('logs'):
    os.mkdir('logs')
current_time = strftime("%Y-%m-%d_%H:%M:%S", gmtime())
logger.add(f'logs/train_{current_time}.log')

# workaround for the duplicate OpenMP runtime problem on macOS
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

with open('config.yaml') as f:
    config = yaml.safe_load(f)

tensorboard_writer = None
if 'USE_TENSORBOARD' in config and config['USE_TENSORBOARD']:
    tensorboard_writer = TensorboardWriter(f'runs/{current_time}')

logger.info(f'config loaded: {config}')
Code Example #16
def main():
    config = get_train_config()

    # device
    device, device_ids = setup_device(config.n_gpu)

    # tensorboard
    writer = TensorboardWriter(config.summary_dir, config.tensorboard)

    # metric tracker
    metric_names = ['loss', 'acc1', 'acc5']
    train_metrics = MetricTracker(*[metric for metric in metric_names],
                                  writer=writer)
    valid_metrics = MetricTracker(*[metric for metric in metric_names],
                                  writer=writer)

    # create model
    print("create model")
    model = VisionTransformer(image_size=(config.image_size,
                                          config.image_size),
                              patch_size=(config.patch_size,
                                          config.patch_size),
                              emb_dim=config.emb_dim,
                              mlp_dim=config.mlp_dim,
                              num_heads=config.num_heads,
                              num_layers=config.num_layers,
                              num_classes=config.num_classes,
                              attn_dropout_rate=config.attn_dropout_rate,
                              dropout_rate=config.dropout_rate)

    # load checkpoint
    if config.checkpoint_path:
        state_dict = load_checkpoint(config.checkpoint_path)
        if config.num_classes != state_dict['classifier.weight'].size(0):
            del state_dict['classifier.weight']
            del state_dict['classifier.bias']
            print("re-initialize fc layer")
            model.load_state_dict(state_dict, strict=False)
        else:
            model.load_state_dict(state_dict)
        print("Load pretrained weights from {}".format(config.checkpoint_path))

    # send model to device
    model = model.to(device)
    if len(device_ids) > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids)

    # create dataloader
    print("create dataloaders")
    train_dataloader = eval("{}DataLoader".format(config.dataset))(
        data_dir=os.path.join(config.data_dir, config.dataset),
        image_size=config.image_size,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        split='train')
    valid_dataloader = eval("{}DataLoader".format(config.dataset))(
        data_dir=os.path.join(config.data_dir, config.dataset),
        image_size=config.image_size,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        split='val')

    # training criterion
    print("create criterion and optimizer")
    criterion = nn.CrossEntropyLoss()

    # create optimizers and learning rate scheduler
    optimizer = torch.optim.SGD(params=model.parameters(),
                                lr=config.lr,
                                weight_decay=config.wd,
                                momentum=0.9)
    lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=config.lr,
        pct_start=config.warmup_steps / config.train_steps,
        total_steps=config.train_steps)

    # start training
    print("start training")
    best_acc = 0.0
    epochs = config.train_steps // len(train_dataloader)
    for epoch in range(1, epochs + 1):
        log = {'epoch': epoch}

        # train the model
        model.train()
        result = train_epoch(epoch, model, train_dataloader, criterion,
                             optimizer, lr_scheduler, train_metrics, device)
        log.update(result)

        # validate the model
        model.eval()
        result = valid_epoch(epoch, model, valid_dataloader, criterion,
                             valid_metrics, device)
        log.update(**{'val_' + k: v for k, v in result.items()})

        # best acc
        best = False
        if log['val_acc1'] > best_acc:
            best_acc = log['val_acc1']
            best = True

        # save model
        save_model(config.checkpoint_dir, epoch, model, optimizer,
                   lr_scheduler, device_ids, best)

        # print logged informations to the screen
        for key, value in log.items():
            print('    {:15s}: {}'.format(str(key), value))