def test(local_save_dir):
    """Restore the latest checkpoint and run the configured evaluation on one GPU.

    Builds the evaluation input pipeline and a single model tower on ``/gpu:0``,
    initializes all variables, then tries to restore weights from
    ``<local_save_dir>/checkpoint`` before handing off to ``get_eval_fn()``.

    Args:
        local_save_dir: Directory containing the ``log`` and ``checkpoint``
            sub-folders; ``log`` is created if it does not exist.
    """
    data_iterator, data_init_op, num_batch = datasets.data_loader(
        is_training=False)
    data = data_iterator.get_next()

    log_dir = os.path.join(local_save_dir, 'log')
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    logger = Logger(log_dir + "/log", level=FLAGS.logger_level)

    with tf.device('/gpu:0'):
        model_fn = get_model_fn()
        model = model_fn(data, is_training=False)

    saver = tf.train.Saver()

    var_init_op = tf.group(tf.local_variables_initializer(),
                           tf.global_variables_initializer())

    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    try:
        sess.run(var_init_op)
        sess.run(data_init_op)

        ckpt = os.path.join(local_save_dir, 'checkpoint')
        # Variables are already initialized above, so when no checkpoint is
        # found evaluation still runs, on the freshly initialized weights.
        load_model_status, _ = utils.load(sess, saver, ckpt)
        if load_model_status:
            print("[*] Model restore success!")
        else:
            print("[*] Not find pretrained model!")

        eval_fn = get_eval_fn()
        eval_fn(model, sess, num_batch, logger)
    finally:
        # Release the session's resources; previously the session was leaked.
        sess.close()
def evaluation(local_save_dir, sess, logger):
    """Evaluate with one model tower per GPU inside an existing session.

    A fresh evaluation input pipeline is created and initialized in ``sess``.
    Every tensor of the batch dict is split evenly along axis 0 into
    ``FLAGS.gpu_num`` shards, one model is built per ``/gpu:<i>``, and the list
    of towers is passed to the function returned by ``get_eval_fn()``.

    Args:
        local_save_dir: Accepted for signature compatibility; not used here.
        sess: The session in which the input pipeline is initialized and
            evaluation runs.
        logger: Forwarded to the evaluation function.
    """
    iterator, init_op, batches = datasets.data_loader(is_training=False)
    sess.run(init_op)
    batch = iterator.get_next()

    # shards[g] maps each batch key to the slice of that tensor for GPU g.
    shards = [dict() for _ in range(FLAGS.gpu_num)]
    for name, tensor in batch.items():
        pieces = tf.split(tensor, FLAGS.gpu_num, axis=0)
        for gpu_id, piece in enumerate(pieces):
            shards[gpu_id][name] = piece

    towers = []
    for gpu_id, shard in enumerate(shards):
        with tf.device('/gpu:%d' % gpu_id):
            build_model = get_model_fn()
            towers.append(build_model(shard, is_training=False))

    run_eval = get_eval_fn()
    run_eval(towers, sess, batches, logger)
    return None
# Code example #3
# 0
def eval(**args):
    """
    Evaluate selected model on one data split

    Args:
        seed       (Int):        Integer indicating set seed for random state (not read here;
                                 forwarded together with the other keys)
        save_dir   (String):     Top level directory to generate results folder
        model      (String):     Name of selected model
        dataset    (String):     Name of selected dataset
        exp        (String):     Name of experiment
        load_type  (String):     Split to evaluate: 'train_val' (validation set),
                                 'train' (training set) or 'test' (test set)
        pretrained (Int/String): Int/String indicating loading of random, pretrained or saved weights;
                                 a string is loaded here as a checkpoint via load_checkpoint
        debug      (Bool):       If True, no result folders, config copy or TensorBoard logs are written

    All keys are also forwarded to create_model_object, data_loader and Metrics.

    Return:
        None

    Raises:
        SystemExit: if load_type is not one of 'train_val', 'train' or 'test'.
    """

    print("\n############################################################################\n")
    print("Experimental Setup: ", args)
    print("\n############################################################################\n")

    date       = datetime.datetime.today().strftime('%Y%m%d-%H%M%S')
    result_dir = os.path.join(args['save_dir'], args['model'], '_'.join((args['dataset'], args['exp'], date)))
    log_dir    = os.path.join(result_dir, 'logs')
    save_dir   = os.path.join(result_dir, 'checkpoints')

    # Only created outside debug mode; always closed in the finally clause below.
    writer = None
    if not args['debug']:
        os.makedirs(result_dir, exist_ok=True)
        os.makedirs(log_dir,    exist_ok=True)
        os.makedirs(save_dir,   exist_ok=True)

        # Save copy of config file
        with open(os.path.join(result_dir, 'config.yaml'), 'w') as outfile:
            yaml.dump(args, outfile, default_flow_style=False)

        # Tensorboard Element
        writer = SummaryWriter(log_dir)

    try:
        # Check if GPU is available (CUDA)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Load Network
        model = create_model_object(**args).to(device)

        # Load Data
        loader = data_loader(**args, model_obj=model)

        # Map the requested load_type onto the key of the loader dict.
        split_keys = {'train_val': 'valid', 'train': 'train', 'test': 'test'}
        if args['load_type'] not in split_keys:
            sys.exit("load_type must be one of 'train_val', 'train' or 'test' for eval, exiting")
        eval_loader = loader[split_keys[args['load_type']]]

        if isinstance(args['pretrained'], str):
            ckpt = load_checkpoint(args['pretrained'])
            model.load_state_dict(ckpt)

        acc_metric = Metrics(**args, result_dir=result_dir, ndata=len(eval_loader.dataset))
        acc = 0.0

        # Setup Model To Evaluate
        model.eval()

        with torch.no_grad():
            for step, data in enumerate(eval_loader):
                x_input     = data['data']
                annotations = data['annots']

                if isinstance(x_input, torch.Tensor):
                    outputs = model(x_input.to(device))
                else:
                    # Multi-input batch: move every tensor to the device. Building a
                    # new list also works when the collated batch is a tuple.
                    x_input = [item.to(device) if isinstance(item, torch.Tensor) else item
                               for item in x_input]
                    outputs = model(*x_input)

                # The value from the final step is what gets reported below.
                acc = acc_metric.get_accuracy(outputs, annotations)

                if step % 100 == 0:
                    print('Step: {}/{} | {} acc: {:.4f}'.format(step, len(eval_loader), args['load_type'], acc))

        print('Accuracy of the network on the {} set: {:.3f} %\n'.format(args['load_type'], 100.*acc))

        if writer is not None:
            writer.add_scalar(args['dataset']+'/'+args['model']+'/'+args['load_type']+'_accuracy', 100.*acc)
    finally:
        # Close Tensorboard Element even if evaluation fails or exits early
        if writer is not None:
            writer.close()
# Code example #4
# 0
# File: train.py  Project: yxl502/text-classify
import torch
import torch.nn as nn
from torch import optim
from models import Model
from datasets import data_loader, text_CLS
from configs import Config

cfg = Config()

# Input files: the Weibo sentiment CSV, a stop-word file and a vocabulary dict
# (roles inferred from the variable names — confirm against text_CLS).
data_path = 'sources/weibo_senti_100k.csv'
data_stop_path = 'sources/hit_stopword'
dict_path = 'sources/dict'
dataset = text_CLS(dict_path, data_path, data_stop_path)
train_dataloader = data_loader(dataset, cfg)
# pad_size is taken from the dataset and set before Model(...) is built —
# presumably Model reads it at construction time.
cfg.pad_size = dataset.max_len_seq

model_text_cls = Model(config=cfg)
model_text_cls.to(cfg.device)

loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_text_cls.parameters(), lr=cfg.learn_rate)

for epoch in range(cfg.num_epochs):
    for i, batch in enumerate(train_dataloader):
        label, data = batch
        data = torch.tensor(data, dtype=torch.int64).to(cfg.device)
        label = torch.tensor(label, dtype=torch.int64).to(cfg.device)

        optimizer.zero_grad()
        # Call the module (not .forward) so registered hooks run.
        pred = model_text_cls(data)
        loss_val = loss_func(pred, label)
        # Back-propagate and apply the update; without these two calls the loop
        # only computes the loss and the model never trains.
        loss_val.backward()
        optimizer.step()
# Code example #5
# 0
def main(args=None):
    """Train or evaluate a classification network from a parsed configuration.

    Workflow: set up logging, select GPUs, build the model from
    ``models.model_zoo``, apply initial/epoch policies, optionally resume from
    ``args.resume_file`` or load ``args.pretrained``, then either evaluate once
    on the validation split (``args.evaluate``) or run the training loop,
    validating (unless 'lr-test' is in ``args.keyword``) and saving a checkpoint
    after every epoch.

    Args:
        args: Parsed configuration namespace. When None it is obtained from
            ``get_parameter()``. Several attributes are rewritten in place
            (e.g. ``case``, ``dataset``, ``workers``, ``device_ids``,
            ``weights_dir``, ``resume_file``, ``pretrained``, ``val_batch_size``,
            ``fp16``, ``stable``, ``warmup``, ``tensorboard``, ``resume``), and
            ``models.get_model`` may return a replacement namespace.

    Returns:
        None. Returns early when the model is not in ``models.model_zoo``,
        after evaluation in ``args.evaluate`` mode, or when the learning-rate
        schedule yields ``None``.
    """
    if args is None:
        args = get_parameter()

    # Fall back from the DALI pipeline to the plain imagenet loader when DALI is unavailable.
    if args.dataset == 'dali' and not dali_enable:
        args.case = args.case.replace('dali', 'imagenet')
        args.dataset = 'imagenet'
        args.workers = 12

    # log_dir
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    model_arch = args.model
    model_name = model_arch
    if args.evaluate:
        log_suffix = 'eval-' + model_arch + '-' + args.case
    else:
        log_suffix = model_arch + '-' + args.case
    utils.setup_logging(os.path.join(args.log_dir, log_suffix + '.txt'),
                        resume=args.resume)

    logging.info("current folder: %r", os.getcwd())
    logging.info("alqnet plugins: %r", plugin_enable)
    logging.info("apex available: %r", apex_enable)
    logging.info("dali available: %r", dali_enable)
    for x in vars(args):
        logging.info("config %s: %r", x, getattr(args, x))

    torch.manual_seed(args.seed)
    # Keep only ids of GPUs that actually exist; None means "run on CPU".
    if torch.cuda.is_available() and len(args.device_ids) > 0:
        args.device_ids = [
            x for x in args.device_ids
            if x < torch.cuda.device_count() and x >= 0
        ]
        if len(args.device_ids) == 0:
            args.device_ids = None
        else:
            logging.info("training on %d gpu", len(args.device_ids))
    else:
        args.device_ids = None

    if args.device_ids is not None:
        torch.cuda.manual_seed_all(args.seed)
        cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True  #https://github.com/pytorch/pytorch/issues/8019
    else:
        logging.info(
            "no gpu available, try CPU version, lots of functions limited")
        #return

    if model_name in models.model_zoo:
        model, args = models.get_model(args)
    else:
        logging.error("model(%s) not support, available models: %r" %
                      (model_name, models.model_zoo))
        return
    criterion = nn.CrossEntropyLoss()
    # Label smoothing is only used as the training loss; validation keeps plain cross entropy.
    if 'label-smooth' in args.keyword:
        criterion_smooth = utils.CrossEntropyLabelSmooth(
            args.num_classes, args.label_smooth)

    # load policy for initial phase
    models.policy.deploy_on_init(model, getattr(args, 'policy', ''))
    # load policy for epoch updating
    epoch_policies = models.policy.read_policy(getattr(args, 'policy', ''),
                                               section='epoch')
    # print model
    logging.info("models: %r" % model)
    logging.info("epoch_policies: %r" % epoch_policies)

    # Checkpoint files live under <weights_dir>/<model_name>/.
    utils.check_folder(args.weights_dir)
    args.weights_dir = os.path.join(args.weights_dir, model_name)
    utils.check_folder(args.weights_dir)
    args.resume_file = os.path.join(args.weights_dir,
                                    args.case + "-" + args.resume_file)
    args.pretrained = os.path.join(args.weights_dir, args.pretrained)
    epoch = 0
    lr = args.lr
    best_acc = 0
    scheduler = None
    checkpoint = None
    # resume training
    if args.resume:
        if utils.check_file(args.resume_file):
            logging.info("resuming from %s" % args.resume_file)
            if torch.cuda.is_available():
                checkpoint = torch.load(args.resume_file)
            else:
                checkpoint = torch.load(args.resume_file, map_location='cpu')
            # Each metadata field is optional; training restarts at the epoch after the saved one.
            if 'epoch' in checkpoint:
                epoch = checkpoint['epoch']
                logging.info("resuming ==> last epoch: %d" % epoch)
                epoch = epoch + 1
                logging.info("updating ==> epoch: %d" % epoch)
            if 'best_acc' in checkpoint:
                best_acc = checkpoint['best_acc']
                logging.info("resuming ==> best_acc: %f" % best_acc)
            if 'learning_rate' in checkpoint:
                lr = checkpoint['learning_rate']
                logging.info("resuming ==> learning_rate: %f" % lr)
            if 'state_dict' in checkpoint:
                utils.load_state_dict(model, checkpoint['state_dict'])
                logging.info("resumed from %s" % args.resume_file)
        else:
            # Missing resume file: continue as a fresh run (this also enables re_init below).
            logging.info("warning: *** resume file not exists({})".format(
                args.resume_file))
            args.resume = False
    else:
        if utils.check_file(args.pretrained):
            logging.info("load pretrained from %s" % args.pretrained)
            if torch.cuda.is_available():
                checkpoint = torch.load(args.pretrained)
            else:
                checkpoint = torch.load(args.pretrained, map_location='cpu')
            logging.info("load pretrained ==> last epoch: %d" %
                         checkpoint.get('epoch', 0))
            logging.info("load pretrained ==> last best_acc: %f" %
                         checkpoint.get('best_acc', 0))
            logging.info("load pretrained ==> last learning_rate: %f" %
                         checkpoint.get('learning_rate', 0))
            #if 'learning_rate' in checkpoint:
            #    lr = checkpoint['learning_rate']
            #    logging.info("resuming ==> learning_rate: %f" % lr)
            # Accept checkpoints wrapping weights under 'state_dict' or 'model', or a bare state dict.
            try:
                utils.load_state_dict(
                    model,
                    checkpoint.get('state_dict',
                                   checkpoint.get('model', checkpoint)))
            except RuntimeError as err:
                logging.info("Loading pretrained model failed %r" % err)
        else:
            logging.info(
                "no pretrained file exists({}), init model with default initlizer"
                .format(args.pretrained))

    # Move model/losses to GPU; wrap in DataParallel only for more than one device.
    if args.device_ids is not None:
        torch.cuda.set_device(args.device_ids[0])
        if not isinstance(model, nn.DataParallel) and len(args.device_ids) > 1:
            model = nn.DataParallel(model, args.device_ids).cuda()
        else:
            model = model.cuda()
        criterion = criterion.cuda()
        if 'label-smooth' in args.keyword:
            criterion_smooth = criterion_smooth.cuda()

    if 'label-smooth' in args.keyword:
        train_criterion = criterion_smooth
    else:
        train_criterion = criterion

    # move after to_cuda() for speedup
    if args.re_init and not args.resume:
        for m in model.modules():
            if hasattr(m, 'init_after_load_pretrain'):
                m.init_after_load_pretrain()

    # dataset
    data_path = args.root
    dataset = args.dataset
    logging.info("loading dataset with batch_size {} and val-batch-size {}. "
                 "dataset: {}, resolution: {}, path: {}".format(
                     args.batch_size, args.val_batch_size, dataset,
                     args.input_size, data_path))

    if args.val_batch_size < 1:
        val_loader = None
    else:
        if args.evaluate:
            # NOTE(review): derived from args.batch_size (rounded down to a
            # multiple of 100), not from args.val_batch_size — confirm intended.
            val_batch_size = (args.batch_size // 100) * 100
            if val_batch_size > 0:
                args.val_batch_size = val_batch_size
            logging.info("update val_batch_size to %d in evaluate mode" %
                         args.val_batch_size)
        val_loader = datasets.data_loader(args.dataset)('val', args)

    # Evaluate-only mode: validate once and return without training.
    if args.evaluate and val_loader is not None:
        # fp16 requires cuDNN, apex and at least one GPU; otherwise it is switched off.
        # (The log texts say "training" although this branch only evaluates.)
        if args.fp16 and torch.backends.cudnn.enabled and apex_enable and args.device_ids is not None:
            logging.info("training with apex fp16 at opt_level {}".format(
                args.opt_level))
        else:
            args.fp16 = False
            logging.info("training without apex")

        if args.fp16:
            # amp.initialize takes an optimizer, so a throwaway Adam is created here.
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)  #
            model, optimizer = amp.initialize(model,
                                              optimizer,
                                              opt_level=args.opt_level)

        logging.info("evaluate the dataset on pretrained model...")
        result = validate(val_loader, model, criterion, args)
        top1, top5, loss = result
        logging.info('evaluate accuracy on dataset: top1(%f) top5(%f)' %
                     (top1, top5))
        return

    train_loader = datasets.data_loader(args.dataset)('train', args)
    # Number of iterations per epoch. NOTE(review): for non-DataLoader loaders
    # (e.g. DALI) this is true division and may be a float — confirm whether
    # the callers of stable/warmup expect an int.
    if isinstance(train_loader, torch.utils.data.dataloader.DataLoader):
        train_length = len(train_loader)
    else:
        train_length = getattr(train_loader, '_size', 0) / getattr(
            train_loader, 'batch_size', 1)

    # sample several iteration / epoch to calculate the initial value of quantization parameters
    if args.stable_epoch > 0 and args.stable <= 0:
        args.stable = train_length * args.stable_epoch
        logging.info("update stable: %d" % args.stable)

    # fix learning rate at the beginning to warmup
    if args.warmup_epoch > 0 and args.warmup <= 0:
        args.warmup = train_length * args.warmup_epoch
        logging.info("update warmup: %d" % args.warmup)

    params_dict = dict(model.named_parameters())
    params = []
    # Collect the '<module>.weight' names whose modules also own a 'quant_weight'
    # parameter, when 'quant_weight' is listed in custom_lr_list.
    quant_wrapper = []
    for key, value in params_dict.items():
        #print(key)
        if 'quant_weight' in key and 'quant_weight' in args.custom_lr_list:
            to_be_quant = key.split('.quant_weight')[0] + '.weight'
            if to_be_quant not in quant_wrapper:
                quant_wrapper += [to_be_quant]
    if len(quant_wrapper) > 0 and args.verbose:
        logging.info("quant_wrapper: {}".format(quant_wrapper))

    # Build one optimizer param group per trainable parameter so weight decay
    # and learning rate can be overridden per parameter.
    for key, value in params_dict.items():
        shape = value.shape
        custom_hyper = dict()
        custom_hyper['params'] = value
        if value.requires_grad == False:
            continue

        found = False
        for i in args.custom_decay_list:
            if i in key and len(i) > 0:
                found = True
                break
        if found:
            custom_hyper['weight_decay'] = args.custom_decay
        elif (not args.decay_small and args.no_decay_small) and (
            (len(shape) == 4 and shape[1] == 1) or (len(shape) == 1)):
            # No decay for 1-D params and 4-D params whose second dim is 1.
            custom_hyper['weight_decay'] = 0.0

        found = False
        for i in args.custom_lr_list:
            if i in key and len(i) > 0:
                found = True
                break
        if found:
            #custom_hyper.setdefault('lr_constant', args.custom_lr) # 2019.11.25
            custom_hyper['lr'] = args.custom_lr
        elif key in quant_wrapper:
            custom_hyper.setdefault('lr_constant', args.custom_lr)
            custom_hyper['lr'] = args.custom_lr

        params += [custom_hyper]

        if 'debug' in args.keyword:
            logging.info("{}, decay {}, lr {}, constant {}".format(
                key, custom_hyper.get('weight_decay', "default"),
                custom_hyper.get('lr', "default"),
                custom_hyper.get('lr_constant', "No")))

    # NOTE(review): optimizer stays None if args.optimizer is neither "ADAM"
    # nor "SGD"; later calls such as optimizer.state_dict() would then fail.
    optimizer = None
    if args.optimizer == "ADAM":
        optimizer = torch.optim.Adam(params,
                                     lr=args.lr,
                                     weight_decay=args.weight_decay)

    if args.optimizer == "SGD":
        optimizer = torch.optim.SGD(params,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)

    if args.resume and checkpoint is not None:
        try:
            optimizer.load_state_dict(checkpoint['optimizer'])
        except RuntimeError as error:
            logging.info("Restore optimizer state failed %r" % error)

    if args.fp16 and torch.backends.cudnn.enabled and apex_enable and args.device_ids is not None:
        logging.info("training with apex fp16 at opt_level {}".format(
            args.opt_level))
    else:
        args.fp16 = False
        logging.info("training without apex")

    if args.sync_bn:
        logging.info("sync_bn to be supported, currently not yet")

    if args.fp16:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level)
        if args.resume and checkpoint is not None:
            try:
                amp.load_state_dict(checkpoint['amp'])
            except RuntimeError as error:
                logging.info("Restore amp state failed %r" % error)

    # start tensorboard as late as possible
    if args.tensorboard and not args.evaluate:
        tb_log = os.path.join(args.log_dir, log_suffix)
        args.tensorboard = SummaryWriter(tb_log,
                                         filename_suffix='.' + log_suffix)
    else:
        args.tensorboard = None

    logging.info("start to train network " + model_name + ' with case ' +
                 args.case)
    # extra_epoch epochs run after args.epochs without updating the LR schedule.
    while epoch < (args.epochs + args.extra_epoch):
        # ProxQuant: decrease each module's 'prox' value as training progresses,
        # either per LR-step stage or linearly per epoch (clamped at 0).
        if 'proxquant' in args.keyword:
            if args.proxquant_step < 10:
                if args.lr_policy in ['sgdr', 'sgdr_step', 'custom_step']:
                    index = len([x for x in args.lr_custom_step if x <= epoch])
                    for m in model.modules():
                        if hasattr(m, 'prox'):
                            m.prox = 1.0 - 1.0 / args.proxquant_step * (index +
                                                                        1)
            else:
                for m in model.modules():
                    if hasattr(m, 'prox'):
                        m.prox = 1.0 - 1.0 / args.proxquant_step * epoch
                        if m.prox < 0:
                            m.prox = 0
        if epoch < args.epochs:
            lr, scheduler = utils.setting_learning_rate(
                optimizer, epoch, train_length, checkpoint, args, scheduler)
        if lr is None:
            logging.info('lr is invalid at epoch %d' % epoch)
            return
        else:
            logging.info('[epoch %d]: lr %e', epoch, lr)

        # Defaults used when training/validation is skipped ('lr-test' mode).
        loss = 0
        top1, top5, eloss = 0, 0, 0
        is_best = top1 > best_acc
        # leverage policies on epoch
        models.policy.deploy_on_epoch(model,
                                      epoch_policies,
                                      epoch,
                                      optimizer=optimizer,
                                      verbose=logging.info)

        if 'lr-test' not in args.keyword:  # otherwise only print the learning rate in each epoch
            # training
            loss = train(train_loader, model, train_criterion, optimizer, args,
                         scheduler, epoch, lr)
            #for i in range(train_length):
            #  scheduler.step()
            logging.info('[epoch %d]: train_loss %.3f' % (epoch, loss))

            # validate
            # NOTE(review): val_loader is None when args.val_batch_size < 1;
            # confirm validate() handles that.
            top1, top5, eloss = 0, 0, 0
            top1, top5, eloss = validate(val_loader, model, criterion, args)
            is_best = top1 > best_acc
            if is_best:
                best_acc = top1
            logging.info('[epoch %d]: test_acc %f %f, best top1: %f, loss: %f',
                         epoch, top1, top5, best_acc, eloss)

        if args.tensorboard is not None:
            args.tensorboard.add_scalar(log_suffix + '/train-loss', loss,
                                        epoch)
            args.tensorboard.add_scalar(log_suffix + '/eval-top1', top1, epoch)
            args.tensorboard.add_scalar(log_suffix + '/eval-top5', top5, epoch)
            args.tensorboard.add_scalar(log_suffix + '/lr', lr, epoch)

        # A checkpoint is written every epoch; is_best is passed through to utils.save_checkpoint.
        utils.save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler':
                None if scheduler is None else scheduler.state_dict(),
                'best_acc': best_acc,
                'learning_rate': lr,
                'amp': None if not args.fp16 else amp.state_dict(),
            }, is_best, args)

        epoch = epoch + 1
        if epoch == 1:
            logging.info(utils.gpu_info())
def train(local_save_dir):
    """Train with synchronous multi-GPU data parallelism in a TF1 graph.

    The input batch dict is split evenly along axis 0 into ``FLAGS.gpu_num``
    shards, one model tower is built per ``/gpu:<i>``, per-tower gradients are
    averaged with ``utils.average_gradients`` and applied by a momentum (0.9)
    optimizer after any ``UPDATE_OPS``. Training resumes from
    ``<local_save_dir>/checkpoint`` when a checkpoint is found. A checkpoint is
    saved every ``FLAGS.save_every_epoch`` epochs, and ``evaluation`` runs every
    ``FLAGS.eval_every_epoch`` epochs.

    Args:
        local_save_dir: Directory holding the ``log`` and ``checkpoint``
            sub-folders; ``log`` is created if missing.
    """
    log_dir = os.path.join(local_save_dir, 'log')
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    logger = Logger(log_dir + "/log", level=FLAGS.logger_level)

    with tf.device('cpu:0'):
        data_iterator, data_init_op, num_batch = datasets.data_loader(
            is_training=True)
        data = data_iterator.get_next()

        # data_split[i] maps each batch key to GPU i's slice of that tensor.
        data_split = [{} for _ in range(FLAGS.gpu_num)]
        for k, t in data.items():
            t_split = tf.split(t, FLAGS.gpu_num, axis=0)
            for i, t_small in enumerate(t_split):
                data_split[i][k] = t_small

        optimizer = tf.train.MomentumOptimizer(FLAGS.base_lr, 0.9)

        grads = []
        display_losses = []
        for i in range(FLAGS.gpu_num):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('%d' % i):
                    model_fn = get_model_fn()
                    model = model_fn(data_split[i],
                                     is_training=True)

                    # A model may expose several losses, each restricted to its own var_list.
                    grads_sub = []
                    for d in model.compute_gradients_losses:
                        grads_sub += optimizer.compute_gradients(
                            loss=d['value'], var_list=d['var_list'])
                    grads.append(grads_sub)

                display_losses += model.display_losses

        grads = utils.average_gradients(grads)

        # Run batch-norm style update ops before every gradient step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer.apply_gradients(grads)

        var_init_op = tf.group(tf.local_variables_initializer(),
                               tf.global_variables_initializer())

        saver = tf.train.Saver(max_to_keep=5)

        sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True))
        sess.run(var_init_op)
        sess.run(data_init_op)

        print(tf.trainable_variables())

        ckpt = os.path.join(local_save_dir, 'checkpoint')
        load_model_status, global_step = utils.load(sess, saver, ckpt)
        if load_model_status:
            iter_num = global_step
            start_epoch = global_step // num_batch
            print("[*] Model restore success!")
        else:
            iter_num = 0
            start_epoch = 0
            print("[*] Not find pretrained model!")

        start = time.time()
        for epoch_id in range(start_epoch, FLAGS.epochs):
            # Separate from log_dir so an empty epoch never logs the directory path.
            log_msg = None
            for batch_id in range(num_batch):
                _, losses_eval = sess.run([train_op, display_losses])

                end = time.time()

                # Group loss values by name; every tower reports the same names.
                losses_dict = {}
                for d in losses_eval:
                    losses_dict.setdefault(d['name'], []).append(d['value'])

                log_msg = "Epoch: [%2d] [%4d/%4d] time: %s | " % (
                    epoch_id+1, batch_id+1, num_batch,
                    str(timedelta(seconds=end-start))[0:10])
                for k, v in losses_dict.items():
                    # Names come back from sess.run as bytes.
                    log_msg += "%s: %.6f " % (k.decode("utf-8"), np.mean(v))
                logger.logger.info(log_msg)
                iter_num += 1

            # Repeat the last batch's line as the end-of-epoch summary.
            if log_msg is not None:
                logger.logger.info(log_msg)

            if np.mod(epoch_id + 1, FLAGS.save_every_epoch) == 0:
                utils.save(sess, saver, iter_num, ckpt)
            if np.mod(epoch_id + 1, FLAGS.eval_every_epoch) == 0:
                evaluation(local_save_dir, sess, logger)

        print("[*] Finish training.")
# Code example #7
# 0
def main():
    """Train or evaluate a segmentation model on PASCAL VOC 2012 (optionally with SBD).

    Sets up logging/TensorBoard and GPUs, builds the model from
    ``models.model_zoo``, optionally resumes from ``args.resume_file`` or loads
    ``args.pretrained``, then either evaluates once on ``args.val_split``
    (``args.evaluate``) or trains for ``args.epochs`` epochs. In training mode it
    validates (Mean IoU) and saves a checkpoint after every epoch.

    Configuration comes from ``get_parameter()``. Returns None; returns early
    when the model or optimizer name is unsupported, or after evaluation.
    """
    args = get_parameter()

    # log_dir
    if not os.path.exists(args.log_dir):
        os.makedirs(args.log_dir)

    # args.model is either an architecture name or a dict with 'base' and 'arch'.
    if isinstance(args.model, str):
        model_arch = args.model
        model_name = model_arch
    else:
        model_name = args.model['arch']
        model_arch = args.model['base'] + '-' + args.model['arch']

    if args.evaluate:
        log_suffix = model_arch + '-eval-' + args.case
    else:
        log_suffix = model_arch + '-' + args.case
    utils.setup_logging(os.path.join(args.log_dir, log_suffix + '.txt'),
                        resume=args.resume)

    # tensorboard
    if args.tensorboard and not args.evaluate:
        args.tensorboard = SummaryWriter(args.log_dir,
                                         filename_suffix='.' + log_suffix)
    else:
        args.tensorboard = None

    torch.manual_seed(args.seed)
    # Keep only ids of GPUs that actually exist; None means "run on CPU".
    if torch.cuda.is_available() and len(args.device_ids) > 0:
        args.device_ids = [
            x for x in args.device_ids
            if x < torch.cuda.device_count() and x >= 0
        ]
        if len(args.device_ids) == 0:
            args.device_ids = None
    else:
        args.device_ids = None

    if args.device_ids is not None:
        torch.cuda.manual_seed_all(args.seed)
        cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True  #https://github.com/pytorch/pytorch/issues/8019

    # GPU-saved checkpoints must be remapped to CPU when CUDA is unavailable
    # (None keeps torch.load's default behavior).
    map_location = None if torch.cuda.is_available() else 'cpu'

    logging.info("device_ids: %s" % args.device_ids)
    logging.info("no_decay_small: %r" % args.no_decay_small)
    logging.info("optimizer: %r" % args.optimizer)
    # The option may arrive as a comma-separated string or already parsed.
    if isinstance(args.lr_custom_step, str):
        args.lr_custom_step = [int(x) for x in args.lr_custom_step.split(',')]
    logging.info("lr_custom_step: %r" % args.lr_custom_step)

    if model_name in models.model_zoo:
        # Every comma-separated keyword becomes a boolean flag for the model factory.
        config = dict()
        for i in args.keyword.split(","):
            i = i.strip()
            logging.info('get keyword %s' % i)
            config[i] = True
        model = models.get_model(args, **config)
    else:
        logging.error("model(%s) not support, available models: %r" %
                      (model_name, models.model_zoo))
        return
    criterion = models.get_loss_function(args)

    # Checkpoint files live under <weights_dir>/<model_name>/.
    utils.check_folder(args.weights_dir)
    args.weights_dir = os.path.join(args.weights_dir, model_name)
    utils.check_folder(args.weights_dir)
    args.resume_file = os.path.join(args.weights_dir,
                                    args.case + "-" + args.resume_file)
    args.pretrained = os.path.join(args.weights_dir, args.pretrained)
    epoch = 0
    best_acc = 0
    # resume training
    if args.resume:
        logging.info("resuming from %s" % args.resume_file)
        checkpoint = torch.load(args.resume_file, map_location=map_location)
        epoch = checkpoint['epoch']
        logging.info("resuming ==> last epoch: %d" % epoch)
        epoch = epoch + 1
        best_acc = checkpoint['best_acc']
        logging.info("resuming ==> best_acc: %f" % best_acc)
        utils.load_state_dict(model, checkpoint['state_dict'])
        logging.info("resumed from %s" % args.resume_file)
    else:
        if utils.check_file(args.pretrained):
            logging.info("resuming from %s" % args.pretrained)
            checkpoint = torch.load(args.pretrained, map_location=map_location)
            # Pretrained files may lack training metadata; default to 0.
            logging.info("resuming ==> last epoch: %d" %
                         checkpoint.get('epoch', 0))
            logging.info("resuming ==> last best_acc: %f" %
                         checkpoint.get('best_acc', 0))
            logging.info("resuming ==> last learning_rate: %f" %
                         checkpoint.get('learning_rate', 0))
            # Accept weights under 'state_dict' or 'model', or a bare state dict.
            utils.load_state_dict(
                model,
                checkpoint.get('state_dict',
                               checkpoint.get('model', checkpoint)))
        else:
            logging.info(
                "no pretrained file exists({}), init model with default initlizer"
                .format(args.pretrained))

    # Move the model to GPU; wrap in DataParallel only for more than one device.
    if args.device_ids is not None:
        torch.cuda.set_device(args.device_ids[0])
        if not isinstance(model, nn.DataParallel) and len(args.device_ids) > 1:
            model = nn.DataParallel(model, args.device_ids).cuda()
        else:
            model = model.cuda()
        #criterion = criterion.cuda()

    # dataset
    data_path = os.path.join(args.root, "VOCdevkit/VOC2012")
    args.sbd = os.path.join(args.root, args.sbd)
    dataset = args.dataset
    logging.info(
        "loading dataset with batch_size {} and val-batch-size {}. dataset {} path: {}"
        .format(args.batch_size, args.val_batch_size, dataset, data_path))
    data_loader = datasets.data_loader(args.dataset)

    if args.val_batch_size < 1:
        val_loader = None
    else:
        val_dataset = data_loader(data_path,
                                  split=args.val_split,
                                  img_size=(args.row, args.col),
                                  sbd_path=args.sbd)
        val_loader = torch.utils.data.DataLoader(
            dataset=val_dataset,
            batch_size=args.val_batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True)

    # Evaluate-only mode: report scores and per-class IoU once, then return.
    if args.evaluate and val_loader is not None:
        logging.info("evaluate the dataset on pretrained model...")
        score, class_iou = validate(val_loader, model, criterion, args)
        for k, v in score.items():
            logging.info("{}: {}".format(k, v))
        for k, v in class_iou.items():
            logging.info("{}: {}".format(k, v))
        return

    # Setup Augmentations
    augmentations = args.aug
    data_aug = datasets.get_composed_augmentations(augmentations)
    train_dataset = data_loader(data_path,
                                split=args.train_split,
                                img_size=(args.row, args.col),
                                sbd_path=args.sbd,
                                augmentations=data_aug)
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=True)

    # One param group per parameter; with no_decay_small, 1-D params and 4-D
    # params whose second dim is 1 get weight_decay 0.
    params_dict = dict(model.named_parameters())
    params = []
    for key, value in params_dict.items():
        shape = value.shape
        if args.no_decay_small and ((len(shape) == 4 and shape[1] == 1) or
                                    (len(shape) == 1)):
            params += [{'params': value, 'weight_decay': 0}]
        else:
            params += [{'params': value}]

    optimizer = None
    if args.optimizer == "ADAM":
        optimizer = torch.optim.Adam(params, lr=args.lr)
    elif args.optimizer == "SGD":
        optimizer = torch.optim.SGD(params,
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=args.nesterov)
    if optimizer is None:
        # Fail early instead of crashing on a None optimizer inside train().
        logging.error("optimizer(%s) not support, available optimizers: %r" %
                      (args.optimizer, ["ADAM", "SGD"]))
        return

    if args.resume:
        optimizer.load_state_dict(checkpoint['optimizer'])

    logging.info("start to train network " + model_name + ' with case ' +
                 args.case)
    while epoch < args.epochs:
        lr = utils.adjust_learning_rate(optimizer, epoch, args)
        logging.info('[epoch %d]: lr %e', epoch, lr)

        # training
        loss = train(train_loader, model, criterion, optimizer, args)
        logging.info('[epoch %d]: train_loss %.3f' % (epoch, loss))

        # validate; the best model is tracked by Mean IoU
        score, class_iou = validate(val_loader, model, criterion, args)
        val_acc = score["Mean IoU : \t"]
        is_best = val_acc > best_acc
        if is_best:
            best_acc = val_acc
        logging.info('[epoch %d]: current acc: %f, best acc: %f', epoch,
                     val_acc, best_acc)

        if args.tensorboard is not None:
            args.tensorboard.add_scalar(log_suffix + '/train-loss', loss,
                                        epoch)
            args.tensorboard.add_scalar(log_suffix + '/eval-acc', val_acc,
                                        epoch)
            args.tensorboard.add_scalar(log_suffix + '/lr', lr, epoch)

        utils.save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_acc': best_acc,
                'learning_rate': lr,
            }, is_best, args)

        epoch = epoch + 1