Example #1
File: main.py Project: agvikas/dcgan-tf
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--parameters', type=str, default='/home/mdo2/Documents/artgan/parameters.json', help='model parameters file')
    parser.add_argument('--mode', type=bool, default=True, help=' True for training or False for inference')
    args = parser.parse_args()

    with open(args.parameters, 'r') as f:
        parameters = json.load(f)
    parameters['mode'] = args.mode

    gen_feed = tf.placeholder(dtype=tf.float32, shape=(None, parameters['noise_length']), name='gen_feed')
    dis_feed = tf.placeholder(dtype=tf.float32, shape=(None, 64, 64, 3), name='dis_feed')
    dis_labels_real = tf.placeholder(dtype=tf.float32, shape=(None, 1), name='dis_labels_real')
    dis_labels_fake = tf.placeholder(dtype=tf.float32, shape=(None, 1), name='dis_labels_fake')
    dis_feed_cond = tf.placeholder(dtype=tf.bool, name='dis_feed_cond')

    model = architecture.Model(parameters)
    gen_output = model.generator(gen_feed)
    dis_out_real = model.discriminator(dis_feed)
    dis_out_fake = model.discriminator(gen_output, reuse=True)     

    if args.mode:  # Training
        train(gen_feed, dis_feed, gen_output, dis_out_real, dis_out_fake, dis_labels_real, dis_labels_fake, parameters)
    else:  # Inference
        infer(model, gen_feed)
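
A note on the --mode flag above: argparse with type=bool converts any non-empty string to True, so passing --mode False on the command line still selects training. A minimal sketch of a safer pattern (the --infer flag below is hypothetical, not part of the original project):

import argparse

parser = argparse.ArgumentParser()
# A store_true flag avoids the bool('False') == True pitfall:
# omit the flag to train (the default), pass --infer to run inference.
parser.add_argument('--infer', action='store_true', help='run inference instead of training')
args = parser.parse_args()
training = not args.infer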
Example #2
def euler():
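    # z-score the whole array with a single scalar mean and std, then flatten each
    # time step into one feature vector per row before handing it to train.infer().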
    data = np.load('euler.npy')
    mean_data = np.mean(data)
    std_data = np.std(data)
    data = (data.reshape([data.shape[0], -1]) - mean_data) / std_data
    n_features = data.shape[-1]
    batch_size = 50
    sequence_length = 240
    input_embed_size = 512
    n_neurons = 1024
    n_layers = 2
    n_gaussians = 5
    use_attention = True
    use_mdn = False
    # model_name = 'seq2seq.ckpt'
    restore_name = 'seq2seq.ckpt-508'

    train.infer(data=data,
                mean_data=mean_data,
                std_data=std_data,
                batch_size=batch_size,
                sequence_length=sequence_length,
                n_features=n_features,
                input_embed_size=input_embed_size,
                n_neurons=n_neurons,
                n_layers=n_layers,
                n_gaussians=n_gaussians,
                use_attention=use_attention,
                use_mdn=use_mdn,
                model_name=restore_name)
Example #3
def classfication_handler():
    data = request.form
    table_id = int(data['id'])
    text = data['text']
    classfication_result = infer([text], model, tokenizer)
    item = {'id': table_id, 'result': str(classfication_result[0])}
    table.put_item(Item=item)
    return json.dumps({'code': 200})
Example #4
    def _inference(self, cand):
        t0 = time.time()
        print('testing model {} ..........'.format(cand))
        recalculate_bn(self.model, cand, self.train_dataprovider)
        torch.cuda.empty_cache()
        recal_bn_time = time.time() - t0
        test_top1_acc, _ = infer(self.val_dataprovider, self.model, self.criterion, cand)
        testtime = time.time() - t0
        print('|=> valid: accuracy = {:.3f}%, total_test_time = {:.2f}s, recal_bn_time={:.2f}s, cand = {}'.format(test_top1_acc, testtime, recal_bn_time, cand))
        return test_top1_acc
Example #5
def do_inference():
    data = np.load('euler.npy')
    data = data.reshape(data.shape[0], -1)
    data_mean = np.mean(data, axis=0)
    data_std = np.std(data, axis=0)
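    # Keep only non-constant features (std > 0) so the per-feature z-scoring below never divides by zero.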
    idxs = np.where(data_std > 0)[0]
    data_mean = data_mean[idxs]
    data_std = data_std[idxs]
    data = (data[:, idxs] - data_mean) / data_std
    n_features = data.shape[-1]
    sequence_length = 60
    input_embed_size = None
    n_neurons = 1024
    n_layers = 3
    n_gaussians = 20
    use_attention = True
    use_mdn = True
    restore_name = 'seq2seq_20-gaussians_3x1024_60-sequence-length_epoch-999'
    batch_size = 1
    offset = 0
    source = data[offset:offset + sequence_length * batch_size, :]
    source = source.reshape(batch_size, sequence_length, -1)
    target = data[offset + sequence_length * batch_size:
                  offset + sequence_length * batch_size * 2, :]
    target = target.reshape(batch_size, sequence_length, -1)

    res = train.infer(source=source,
                      target=target,
                      data_mean=data_mean,
                      data_std=data_std,
                      batch_size=batch_size,
                      sequence_length=sequence_length,
                      n_features=n_features,
                      input_embed_size=input_embed_size,
                      n_neurons=n_neurons,
                      n_layers=n_layers,
                      n_gaussians=n_gaussians,
                      use_attention=use_attention,
                      use_mdn=use_mdn,
                      model_name=restore_name)

    np.save('source.npy', res['source'])
    np.save('target.npy', res['target'])
    np.save('encoding.npy', res['encoding'])
    np.save('prediction.npy', res['prediction'])
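
The arrays saved above are still in the z-scored space used for training. A minimal post-processing sketch for mapping a saved prediction back to the original units, assuming train.infer returns outputs in the same normalized space as its inputs:

import numpy as np

# Recompute the statistics exactly as do_inference() did.
data = np.load('euler.npy')
data = data.reshape(data.shape[0], -1)
data_mean = np.mean(data, axis=0)
data_std = np.std(data, axis=0)
idxs = np.where(data_std > 0)[0]

# Undo the z-scoring; only the kept (non-constant) features are present here.
prediction = np.load('prediction.npy')
prediction_denorm = prediction * data_std[idxs] + data_mean[idxs]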
Example #6
def euler_v3():
    data = np.load('euler.npy')
    mean_data = np.mean(data)
    std_data = np.std(data)
    data = (data.reshape([data.shape[0], -1]) - mean_data) / std_data
    n_features = data.shape[-1]
    sequence_length = 120
    input_embed_size = None
    n_neurons = 512
    n_layers = 3
    n_gaussians = 10
    use_attention = False
    use_mdn = False
    model_name = 'seq2seq-v3.ckpt-429'

    hop_length = 60
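    # Slide a window of sequence_length frames across the data every hop_length frames;
    # each source window is paired with the window that immediately follows it as its target.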
    idxs = np.arange(0, len(data) - sequence_length * 2, hop_length)
    source = np.array([data[i:i + sequence_length, :] for i in idxs])
    target = np.array(
        [data[i + sequence_length:i + sequence_length * 2, :] for i in idxs])
    batch_size = len(idxs)

    res = train.infer(source=source,
                      target=target,
                      mean_data=mean_data,
                      std_data=std_data,
                      batch_size=batch_size,
                      sequence_length=sequence_length,
                      n_features=n_features,
                      input_embed_size=input_embed_size,
                      n_neurons=n_neurons,
                      n_layers=n_layers,
                      n_gaussians=n_gaussians,
                      use_attention=use_attention,
                      use_mdn=use_mdn,
                      model_name=model_name)
Example #7
def run_model(config,
         seed=0,
         data_dir='./data',
         genotype_class='PCDARTS',
         num_epochs=20,
         batch_size=get('batch_size'),
         init_channels=get('init_channels'),
         train_criterion=torch.nn.CrossEntropyLoss,
         data_augmentations=None,
         save_model_str=None, **kwargs):
    """
    Training loop for configurableNet.
    :param model_config: network config (dict)
    :param data_dir: dataset path (str)
    :param num_epochs: (int)
    :param batch_size: (int)
    :param learning_rate: model optimizer learning rate (float)
    :param train_criterion: Which loss to use during training (torch.nn._Loss)
    :param model_optimizer: Which model optimizer to use during trainnig (torch.optim.Optimizer)
    :param data_augmentations: List of data augmentations to apply such as rescaling.
        (list[transformations], transforms.Composition[list[transformations]], None)
        If none only ToTensor is used
    :return:
    """


    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    gpu = 'cuda:0'
    np.random.seed(seed)
    torch.cuda.set_device(gpu)
    cudnn.benchmark = True
    torch.manual_seed(seed)
    cudnn.enabled=True
    torch.cuda.manual_seed(seed)
    logging.info('gpu device = %s' % gpu)
    logging.info("config = %s", config)

    if data_augmentations is None:
        # You can add any preprocessing/data augmentation you want here
        data_augmentations = transforms.ToTensor()
    elif isinstance(data_augmentations, list):
        data_augmentations = transforms.Compose(data_augmentations)
    elif not isinstance(data_augmentations, transforms.Compose):
        raise NotImplementedError

    train_dataset = K49(data_dir, True, data_augmentations)
    test_dataset = K49(data_dir, False, data_augmentations)
    # train_dataset = KMNIST(data_dir, True, data_augmentations)
    # test_dataset = KMNIST(data_dir, False, data_augmentations)
    # Make data batch iterable
    # Could modify the sampler to not uniformly random sample
    
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=True)
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=batch_size,
                             shuffle=False)

    genotype = eval("genotypes.%s" % genotype_class)
    model = Network(init_channels, train_dataset.n_classes, config['n_conv_layers'], genotype)
    model = model.cuda()
    
    total_model_params = sum(p.numel() for p in model.parameters())

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    criterion = train_criterion()  # train_criterion is a loss class, e.g. torch.nn.CrossEntropyLoss
    criterion = criterion.cuda()
    
    if config['optimizer'] == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), 
                                    lr=config['initial_lr'], 
                                    momentum=config['sgd_momentum'], 
                                    weight_decay=config['weight_decay'], 
                                    nesterov=config['nesterov'])
    else:
        optimizer = get('opti_dict')[config['optimizer']](model.parameters(), lr=config['initial_lr'], weight_decay=config['weight_decay'])
    
    if config['lr_scheduler'] == 'Cosine':
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)
    elif config['lr_scheduler'] == 'Exponential':
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1)

    logging.info('Generated Network:')
    summary(model, (train_dataset.channels,
                    train_dataset.img_rows,
                    train_dataset.img_cols),
            device='cuda' if torch.cuda.is_available() else 'cpu')
    for epoch in range(num_epochs):
        lr_scheduler.step()
        logging.info('epoch %d lr %e', epoch, lr_scheduler.get_lr()[0])
        model.drop_path_prob = config['drop_path_prob'] * epoch / num_epochs

        train_acc, train_obj = train(train_loader, model, criterion, optimizer, grad_clip=config['grad_clip_value'])
        logging.info('train_acc %f', train_acc)

        test_acc, test_obj = infer(test_loader, model, criterion)
        logging.info('test_acc %f', test_acc)


    if save_model_str:
        # Save the model weights; restore by rebuilding the network and calling
        # "model.load_state_dict(torch.load(save_model_str))"
        if os.path.exists(save_model_str):
            save_model_str += '_' + '_'.join(time.ctime().split())
        torch.save(model.state_dict(), save_model_str)
    
    return test_acc
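
Because run_model saves only the state_dict, reloading the checkpoint requires rebuilding the network first. A minimal sketch, assuming the same constructor arguments used above (init_channels, config, genotype, and the dataset's class count) are still available:

import torch

# Rebuild the architecture with the arguments used during training,
# then load the saved weights into it.
model = Network(init_channels, train_dataset.n_classes, config['n_conv_layers'], genotype)
model.load_state_dict(torch.load(save_model_str))
model.eval()  # switch to inference mode before evaluating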
Example #8
    def compute(self, x, budget, config, **kwargs):
        """
        Get model with hyperparameters from config generated by get_configspace()
        """
        config = get_config_dictionary(x, config)
        print("config", config)
        if len(config.keys()) < len(x):
            return 100
        if not torch.cuda.is_available():
            logging.info('no gpu device available')
            sys.exit(1)

        gpu = 'cuda:0'
        np.random.seed(self.seed)
        torch.cuda.set_device(gpu)
        cudnn.benchmark = True
        torch.manual_seed(self.seed)
        cudnn.enabled=True
        torch.cuda.manual_seed(self.seed)
        logging.info('gpu device = %s' % gpu)
        logging.info("config = %s", config)

        genotype = eval("genotypes.%s" % 'PCDARTS')
        model = Network(self.init_channels, self.n_classes, config['n_conv_layers'], genotype)
        model = model.cuda()

        logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

        criterion = nn.CrossEntropyLoss()
        criterion = criterion.cuda()
        
        if config['optimizer'] == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), 
                                        lr=config['initial_lr'], 
                                        momentum=0.9, 
                                        weight_decay=config['weight_decay'], 
                                        nesterov=True)
        else:
            optimizer = settings.opti_dict[config['optimizer']](model.parameters(), lr=config['initial_lr'])
        
        if config['lr_scheduler'] == 'Cosine':
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, int(budget))
        elif config['lr_scheduler'] == 'Exponential':
            lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1)

        
        indices = list(range(int(self.split*len(self.train_dataset))))
        valid_indices =  list(range(int(self.split*len(self.train_dataset)), len(self.train_dataset)))
        print("Training size=", len(indices))
        training_sampler = SubsetRandomSampler(indices)
        valid_sampler = SubsetRandomSampler(valid_indices)
        train_queue = torch.utils.data.DataLoader(dataset=self.train_dataset,
                                                batch_size=self.batch_size,
                                                sampler=training_sampler) 

        valid_queue = torch.utils.data.DataLoader(dataset=self.train_dataset, 
                                                batch_size=self.batch_size, 
                                                sampler=valid_sampler)


        for epoch in range(int(budget)):
            lr_scheduler.step()
            logging.info('epoch %d lr %e', epoch, lr_scheduler.get_lr()[0])
            model.drop_path_prob = config['drop_path_prob'] * epoch / int(budget)

            train_acc, train_obj = train(train_queue, model, criterion, optimizer, grad_clip=config['grad_clip_value'])
            logging.info('train_acc %f', train_acc)

            valid_acc, valid_obj = infer(valid_queue, model, criterion)
            logging.info('valid_acc %f', valid_acc)

        return valid_obj  # Hyperband always minimizes, so return the validation loss as the objective
Example #9
def main():
  if not torch.cuda.is_available():
    logger.info('no gpu device available')
    sys.exit(1)

  np.random.seed(args.seed)
  torch.cuda.set_device(args.gpu)
  cudnn.benchmark = True
  torch.manual_seed(args.seed)
  cudnn.enabled=True
  torch.cuda.manual_seed(args.seed)
  logger.info('gpu device = %d' % args.gpu)
  logger.info("args = %s", args)

  # # load the correct ops dictionary
  op_dict_to_load = "operations.%s" % args.ops
  logger.info('loading op dict: ' + str(op_dict_to_load))
  op_dict = eval(op_dict_to_load)

  # load the correct primitives list
  primitives_to_load = "genotypes.%s" % args.primitives
  logger.info('loading primitives:' + primitives_to_load)
  primitives = eval(primitives_to_load)
  logger.info('primitives: ' + str(primitives))

  genotype = eval("genotypes.%s" % args.arch)
  cnn_model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype, op_dict=op_dict, C_mid=args.mid_channels)
  if args.parallel:
    cnn_model = nn.DataParallel(cnn_model).cuda()
  else:
    cnn_model = cnn_model.cuda()

  logger.info("param size = %fMB", utils.count_parameters_in_MB(cnn_model))
  if args.flops:
    cnn_model.drop_path_prob = 0.0
    logger.info("flops = " + utils.count_model_flops(cnn_model, data_shape=[1, 3, 224, 224]))
    exit(1)

  criterion = nn.CrossEntropyLoss()
  criterion = criterion.cuda()
  criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
  criterion_smooth = criterion_smooth.cuda()

  optimizer = torch.optim.SGD(
    cnn_model.parameters(),
    args.learning_rate,
    momentum=args.momentum,
    weight_decay=args.weight_decay
    )

  traindir = os.path.join(args.data, 'train')
  validdir = os.path.join(args.data, 'val')
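  # The normalization constants below are the standard ImageNet channel means and stds.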
  normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  train_data = dset.ImageFolder(
    traindir,
    transforms.Compose([
      transforms.RandomResizedCrop(224),
      transforms.RandomHorizontalFlip(),
      autoaugment.ImageNetPolicy(),
      # transforms.ColorJitter(
      #   brightness=0.4,
      #   contrast=0.4,
      #   saturation=0.4,
      #   hue=0.2),
      transforms.ToTensor(),
      normalize,
    ]))
  valid_data = dset.ImageFolder(
    validdir,
    transforms.Compose([
      transforms.Resize(256),
      transforms.CenterCrop(224),
      transforms.ToTensor(),
      normalize,
    ]))

  train_queue = torch.utils.data.DataLoader(
    train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=8)

  valid_queue = torch.utils.data.DataLoader(
    valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=8)

  scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.decay_period, gamma=args.gamma)

  prog_epoch = tqdm(range(args.epochs), dynamic_ncols=True)
  best_valid_acc = 0.0
  best_epoch = 0
  best_stats = {}
  best_acc_top1 = 0
  weights_file = os.path.join(args.save, 'weights.pt')
  for epoch in prog_epoch:
    scheduler.step()
    cnn_model.drop_path_prob = args.drop_path_prob * epoch / args.epochs

    train_acc, train_obj = train.train(args, train_queue, cnn_model, criterion, optimizer)
    stats = train.infer(args, valid_queue, cnn_model, criterion)

    is_best = False
    if stats['valid_acc'] > best_valid_acc:
      # new best epoch, save weights
      utils.save(cnn_model, weights_file)
      best_epoch = epoch
      best_valid_acc = stats['valid_acc']

      best_stats = stats
      best_stats['lr'] = scheduler.get_lr()[0]
      best_stats['epoch'] = best_epoch
      best_train_loss = train_obj
      best_train_acc = train_acc
      is_best = True

    logger.info('epoch, %d, train_acc, %f, valid_acc, %f, train_loss, %f, valid_loss, %f, lr, %e, best_epoch, %d, best_valid_acc, %f, ' + utils.dict_to_log_string(stats),
                epoch, train_acc, stats['valid_acc'], train_obj, stats['valid_loss'], scheduler.get_lr()[0], best_epoch, best_valid_acc)
    checkpoint = {
          'epoch': epoch,
          'state_dict': cnn_model.state_dict(),
          'best_acc_top1': best_valid_acc,
          'optimizer' : optimizer.state_dict(),
    }
    checkpoint.update(stats)
    utils.save_checkpoint(stats, is_best, args.save)

  best_epoch_str = utils.dict_to_log_string(best_stats, key_prepend='best_')
  logger.info(best_epoch_str)
  logger.info('Training of Final Model Complete! Save dir: ' + str(args.save))
Example #10
    def compute(self, config, budget, *args, **kwargs):
        """
        Get model with hyperparameters from config generated by get_configspace()
        """
        if not torch.cuda.is_available():
            logging.info('no gpu device available')
            sys.exit(1)

        logging.info(f'Running config for {budget} epochs')
        gpu = 'cuda:0'
        np.random.seed(self.seed)
        torch.cuda.set_device(gpu)
        cudnn.benchmark = True
        torch.manual_seed(self.seed)
        cudnn.enabled = True
        torch.cuda.manual_seed(self.seed)
        logging.info('gpu device = %s' % gpu)
        logging.info("config = %s", config)

        ensemble_model = EnsembleModel(self.trained_models,
                                       dense_units=config['dense_units'],
                                       out_size=self.train_dataset.n_classes)
        ensemble_model = ensemble_model.cuda()

        logging.info("param size = %fMB",
                     utils.count_parameters_in_MB(ensemble_model))

        criterion = nn.CrossEntropyLoss()
        criterion = criterion.cuda()

        if config['optimizer'] == 'sgd':
            optimizer = torch.optim.SGD(ensemble_model.parameters(),
                                        lr=config['initial_lr'],
                                        momentum=config['sgd_momentum'],
                                        weight_decay=config['weight_decay'],
                                        nesterov=config['nesterov'])
        else:
            optimizer = get('opti_dict')[config['optimizer']](
                ensemble_model.parameters(),
                lr=config['initial_lr'],
                weight_decay=config['weight_decay'])

        if config['lr_scheduler'] == 'Cosine':
            lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, int(budget))
        elif config['lr_scheduler'] == 'Exponential':
            lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer,
                                                                  gamma=0.1)

        # Random (possibly overlapping) index draws: roughly the first two thirds of the
        # dataset for training, the remainder for validation.
        # (previously: list(range(int(self.split*len(self.train_dataset)))) and
        #  list(range(int(self.split*len(self.train_dataset)), len(self.train_dataset))))
        indices = list(np.random.randint(0,
                                         2 * len(self.train_dataset) // 3,
                                         size=len(self.train_dataset) // 3))
        valid_indices = list(np.random.randint(2 * len(self.train_dataset) // 3,
                                               len(self.train_dataset),
                                               size=len(self.train_dataset) // 6))
        print("Training size=", len(indices))
        training_sampler = SubsetRandomSampler(indices)
        valid_sampler = SubsetRandomSampler(valid_indices)
        train_queue = torch.utils.data.DataLoader(dataset=self.train_dataset,
                                                  batch_size=self.batch_size,
                                                  sampler=training_sampler)

        valid_queue = torch.utils.data.DataLoader(dataset=self.train_dataset,
                                                  batch_size=self.batch_size,
                                                  sampler=valid_sampler)

        for epoch in range(int(budget)):
            logging.info('epoch %d lr %e', epoch, lr_scheduler.get_lr()[0])
            ensemble_model.drop_path_prob = config[
                'drop_path_prob'] * epoch / int(budget)

            train_acc, train_obj = ensemble_train(
                train_queue,
                ensemble_model,
                criterion,
                optimizer,
                grad_clip=config['grad_clip_value'])
            logging.info('train_acc %f', train_acc)
            lr_scheduler.step()

            valid_acc, valid_obj = infer(valid_queue, ensemble_model,
                                         criterion)
            logging.info('valid_acc %f', valid_acc)

        return ({
            'loss': valid_obj,  # Hyperband always minimizes, so return the validation loss
            'info': {}  # mandatory - can be used in the future to give more information
        })
Example #11
def main():

    args = get_config()

    #---display model type---#
    print('-' * 80)
    if args.conditional:
        print('# Training Conditional SpecGAN!')
    else:
        print('# Training SpecGAN!')
    print('-' * 80)

    #---make train dir---#
    if args.conditional:
        args.train_dir += '_cond'
    if not os.path.isdir(args.train_dir):
        os.makedirs(args.train_dir)

    #---save args---#
    with open(os.path.join(args.train_dir, 'args.txt'), 'w') as f:
        f.write('\n'.join([
            str(k) + ',' + str(v)
            for k, v in sorted(vars(args).items(), key=lambda x: x[0])
        ]))

    #---make model kwarg dicts---#
    setattr(
        args, 'SpecGAN_g_kwargs', {
            'kernel_len': args.SpecGAN_kernel_len,
            'dim': args.SpecGAN_dim,
            'use_batchnorm': args.SpecGAN_batchnorm,
            'upsample': args.SpecGAN_genr_upsample,
            'initializer': args.SpecGAN_model_initializer,
        })
    setattr(
        args, 'SpecGAN_d_kwargs', {
            'kernel_len': args.SpecGAN_kernel_len,
            'dim': args.SpecGAN_dim,
            'use_batchnorm': args.SpecGAN_batchnorm,
            'initializer': args.SpecGAN_model_initializer,
        })

    #---collect path to data---#
    if args.mode == 'train' or args.mode == 'moments':
        fps = glob.glob(
            os.path.join(args.data_dir, args.data_tfrecord_prefix) +
            '*.tfrecord')

    #---load moments---#
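    # If the pickled moments file cannot be read below, a new one is computed with moments() and the load is retried.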
    if args.mode != 'moments' and args.data_moments_file is not None:
        while True:
            try:
                print('# Moments: Loading existing moments file...')
                with open(os.path.join(args.train_dir, args.data_moments_file), 'rb') as f:
                    _mean, _std = pickle.load(f)
                break
            except Exception:
                print('# Moments: Failed to load, computing new moments file...')
                moments(fps, args)
        setattr(args, 'data_moments_mean', _mean)
        setattr(args, 'data_moments_std', _std)

    #---run selected mode---#
    if args.mode == 'train':
        infer(args, cond=args.conditional)
        train(fps, args, cond=args.conditional)
    elif args.mode == 'generate':
        infer(args, cond=args.conditional)
        generate(args, cond=args.conditional)
    elif args.mode == 'moments':
        moments(fps, args)
    elif args.mode == 'preview':
        preview(args)
    elif args.mode == 'incept':
        incept(args)
    elif args.mode == 'infer':
        infer(args)
    else:
        raise NotImplementedError()
Example #12
File: main.py Project: ravinkohli/resnet
def model_train(model, config, criterion, trainloader, testloader, validloader,
                model_name):
    num_epochs = config['budget']
    success = False
    time_to_94 = None

    lrs = list()
    logging.info(f"weight decay:\t{config['weight_decay']}")
    logging.info(f"momentum :\t{config['momentum']}")

    base_optimizer = optim.SGD(model.parameters(),
                               lr=config['base_lr'],
                               weight_decay=config['weight_decay'],
                               momentum=config['momentum'])
    if config['swa']:
        optimizer = torchcontrib.optim.SWA(base_optimizer)

        # lr_scheduler = SWAResNetLR(optimizer, milestones=config['milestones'], schedule=config['schedule'], swa_start=config['swa_start'], swa_init_lr=config['swa_init_lr'], swa_step=config['swa_step'], base_lr=config['base_lr'])
    else:
        optimizer = base_optimizer
        # lr_scheduler = PiecewiseLinearLR(optimizer, milestones=config['milestones'], schedule=config['schedule'])

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, num_epochs)
    #lr_scheduler = PiecewiseLinearLR(optimizer, milestones=config['milestones'], schedule=config['schedule'])
    save_model_str = './models/'

    if not os.path.exists(save_model_str):
        os.mkdir(save_model_str)

    save_model_str += f'model_({datetime.datetime.now()})'
    if not os.path.exists(save_model_str):
        os.mkdir(save_model_str)

    summary_dir = f'{save_model_str}/summary'
    if not os.path.exists(summary_dir):
        os.mkdir(summary_dir)
    c = datetime.datetime.now()
    train_meter = AccuracyMeter(model_dir=summary_dir, name='train')
    test_meter = AccuracyMeter(model_dir=summary_dir, name='test')
    valid_meter = AccuracyMeter(model_dir=summary_dir, name='valid')

    for epoch in range(num_epochs):
        lr = lr_scheduler.get_lr()[0]
        lrs.append(lr)

        logging.info('epoch %d, lr %e', epoch, lr)

        train_acc, train_obj, time = train(trainloader, model, criterion,
                                           optimizer, model_name,
                                           config['grad_clip'],
                                           config['prefetch'])

        train_meter.update({
            'acc': train_acc,
            'loss': train_obj
        }, time.total_seconds())
        lr_scheduler.step()
        if config['swa'] and ((epoch + 1) >= config['swa_start']) and (
            (epoch + 1 - config['swa_start']) % config['swa_step'] == 0):
            optimizer.update_swa()
        valid_acc, valid_obj, time = infer(testloader,
                                           model,
                                           criterion,
                                           name=model_name,
                                           prefetch=config['prefetch'])
        valid_meter.update({
            'acc': valid_acc,
            'loss': valid_obj
        }, time.total_seconds())
        if valid_acc >= 94:
            success = True
            time_to_94 = train_meter.time
            logging.info(f'Time to reach 94% {time_to_94}')
        # wandb.log({"Test Accuracy":valid_acc, "Test Loss": valid_obj, "Train Accuracy":train_acc, "Train Loss": train_obj})

    a = datetime.datetime.now() - c
    if config['swa']:
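        # Swap the averaged SWA weights into the model and recompute BatchNorm running
        # statistics over the training data before the final evaluation (torchcontrib SWA API).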
        optimizer.swap_swa_sgd()
        optimizer.bn_update(trainloader, model)
    test_acc, test_obj, time = infer(testloader,
                                     model,
                                     criterion,
                                     name=model_name,
                                     prefetch=config['prefetch'])
    test_meter.update({
        'acc': test_acc,
        'loss': test_obj
    }, time.total_seconds())
    torch.save(model.state_dict(), f'{save_model_str}/state')
    # wandb.save('model.h5')
    train_meter.plot(save_model_str)
    valid_meter.plot(save_model_str)

    plt.plot(lrs)
    plt.title('LR vs epochs')
    plt.xlabel('Epochs')
    plt.ylabel('LR')
    plt.xticks(np.arange(0, num_epochs, 5))
    plt.savefig(f'{save_model_str}/lr_schedule.png')
    plt.close()

    device = get('device')
    device_name = (cpuinfo.get_cpu_info()['brand'] if device.type == 'cpu'
                   else torch.cuda.get_device_name(0))
    total_time = round(a.total_seconds(), 2)
    logging.info(
        f'test_acc: {test_acc}, save_model_str:{save_model_str}, total time :{total_time} and device used {device_name}'
    )
    _, cnt, time = train_meter.get()
    time_per_step = round(time / cnt, 2)
    return_dict = {
        'test_acc': test_acc,
        'save_model_str': save_model_str,
        'training_time_per_step': time_per_step,
        'total_train_time': time,
        'total_time': total_time,
        'device_used': device_name,
        'train_acc': train_acc
    }
    if success:
        return_dict['time_to_94'] = time_to_94
    return return_dict, model