def train_and_evaluate(model, train_dataloader, val_dataloader, optimizer,
                       loss_fn, metrics, params, model_dir, restore_file=None):
    """Train the model and evaluate every epoch.

    Args:
        model: (torch.nn.Module) the neural network
        params: (Params) hyperparameters
        model_dir: (string) directory containing config, weights and log
        restore_file: (string) - name of file to restore from (without its extension .pth.tar)
    """
    # reload weights from restore_file if specified
    if restore_file is not None:
        restore_path = os.path.join(model_dir, restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)

    best_val_acc = 0.0

    # learning rate schedulers for different models:
    if params.model_version == "resnet18":
        scheduler = StepLR(optimizer, step_size=150, gamma=0.1)
    # for cnn models, num_epochs is always < 100, so with step_size=100 the
    # scheduler below never actually decays the learning rate
    elif params.model_version == "cnn":
        scheduler = StepLR(optimizer, step_size=100, gamma=0.2)

    for epoch in range(params.num_epochs):
     
        scheduler.step()
     
        # Run one epoch
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

        # compute number of batches in one epoch (one full pass over the training set)
        train(model, optimizer, loss_fn, train_dataloader, metrics, params)

        # Evaluate for one epoch on validation set
        val_metrics = evaluate(model, loss_fn, val_dataloader, metrics, params)        

        val_acc = val_metrics['accuracy']
        is_best = val_acc >= best_val_acc

        # Save weights
        utils.save_checkpoint({'epoch': epoch + 1,
                               'state_dict': model.state_dict(),
                               'optim_dict' : optimizer.state_dict()},
                               is_best=is_best,
                               checkpoint=model_dir)

        # If best_eval, best_save_path
        if is_best:
            logging.info("- Found new best accuracy")
            best_val_acc = val_acc

            # Save best val metrics in a json file in the model directory
            best_json_path = os.path.join(model_dir, "metrics_val_best_weights.json")
            utils.save_dict_to_json(val_metrics, best_json_path)

        # Save latest val metrics in a json file in the model directory
        last_json_path = os.path.join(model_dir, "metrics_val_last_weights.json")
        utils.save_dict_to_json(val_metrics, last_json_path)
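
Note: nearly every snippet on this page calls a pair of utils.save_checkpoint / utils.load_checkpoint helpers whose bodies are never shown, and the exact implementations differ per project. A minimal sketch of the common pattern implied by the calls above (a dict holding 'state_dict' and 'optim_dict', a checkpoint directory, and a copy of the best weights); the names and file layout here are assumptions, not any particular project's implementation:

import os
import shutil
import torch

def save_checkpoint(state, is_best, checkpoint):
    """Save `state` to <checkpoint>/last.pth.tar and copy it to best.pth.tar when is_best."""
    if not os.path.exists(checkpoint):
        os.makedirs(checkpoint)
    filepath = os.path.join(checkpoint, 'last.pth.tar')
    torch.save(state, filepath)
    if is_best:
        shutil.copyfile(filepath, os.path.join(checkpoint, 'best.pth.tar'))

def load_checkpoint(checkpoint, model, optimizer=None):
    """Load model (and optionally optimizer) state from a .pth.tar file."""
    if not os.path.exists(checkpoint):
        raise FileNotFoundError("File doesn't exist {}".format(checkpoint))
    state = torch.load(checkpoint)
    model.load_state_dict(state['state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(state['optim_dict'])
    return state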
Example #2
def train_and_evaluate(model, optimizer, scheduler, params,
                       restore_file=None):
    """Train the model and evaluate every epoch."""
    # load args
    args = parser.parse_args()
    # reload weights from restore_file if specified
    if restore_file is not None:
        restore_path = os.path.join(params.model_dir, restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        # load the checkpoint
        utils.load_checkpoint(restore_path, model, optimizer)

    # Load training data and val data
    dataloader = NERDataLoader(params)
    train_loader, val_loader = dataloader.load_data(mode='train')

    # patience stage
    best_val_f1 = 0.0
    patience_counter = 0

    for epoch in range(1, args.epoch_num + 1):
        # Run one epoch
        logging.info("Epoch {}/{}".format(epoch, args.epoch_num))

        # number of training steps in one epoch
        params.train_steps = len(train_loader)

        # Train for one epoch on training set
        train(model, train_loader, optimizer, params)

        # Evaluate for one epoch on training set and validation set
        train_metrics = evaluate(model, train_loader, params, mark='Train',
                                 verbose=True)  # Dict['loss', 'f1']
        val_metrics = evaluate(model, val_loader, params, mark='Val',
                               verbose=True)  # Dict['loss', 'f1']

        # step the lr_scheduler to decay the learning rate
        scheduler.step()

        # f1-score on the validation set
        val_f1 = val_metrics['f1']
        # improvement in f1-score over the current best
        improve_f1 = val_f1 - best_val_f1

        # Save weights of the network
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
        optimizer_to_save = optimizer
        utils.save_checkpoint({'epoch': epoch + 1,
                               'state_dict': model_to_save.state_dict(),
                               'optim_dict': optimizer_to_save.state_dict()},
                              is_best=improve_f1 > 0,
                              checkpoint=params.model_dir)
        params.save(params.params_path / 'params.json')

        # stop training based on params.patience
        if improve_f1 > 0:
            logging.info("- Found new best F1")
            best_val_f1 = val_f1
            if improve_f1 < params.patience:
                patience_counter += 1
            else:
                patience_counter = 0
        else:
            patience_counter += 1

        # Early stopping and logging best f1
        if (patience_counter > params.patience_num and epoch > params.min_epoch_num) or epoch == args.epoch_num:
            logging.info("Best val f1: {:05.2f}".format(best_val_f1))
            break
                  embedding_dim=128,
                  negative_sample=False)
    model = utils.cuda(model, args.gpu_id)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    ### For annealing the learning rate
    # LambdaLR passes the epoch index to lr_lambda and multiplies the initial
    # lr by the returned factor, so 0.99 ** epoch gives exponential decay
    lambda1 = lambda epoch: 0.99 ** epoch
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,
                                                     lr_lambda=lambda1)

    if not os.path.isdir(args.checkpoint_dir):
        os.makedirs(args.checkpoint_dir)

    try:
        ckpt = utils.load_checkpoint(args.checkpoint_dir + '/latest_model_' +
                                     args.dataset)
        start_epoch = ckpt['epoch']
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optimizer'])
    except Exception:
        print(' [*] No checkpoint!')
        start_epoch = 0

    for epoch in range(start_epoch, args.epochs):

        optimizer.zero_grad()
        model.train()

        w = torch.cat((edge_index[0, :], edge_index[1, :]))
        c = torch.cat((edge_index[1, :], edge_index[0, :]))
Example #4
def train_and_evaluate(model: nn.Module,
                       train_loader: DataLoader,
                       test_loader: DataLoader,
                       optimizer: optim, loss_fn,
                       params: utils.Params,
                       restore_file: str = None) -> None:
    '''Train the model and evaluate every epoch.
    Args:
        model: (torch.nn.Module) the Deep AR model
        train_loader: load train data and labels
        test_loader: load test data and labels
        optimizer: (torch.optim) optimizer for parameters of model
        loss_fn: a function that takes outputs and labels per timestep, and then computes the loss for the batch
        params: (Params) hyperparameters
        restore_file: (string) optional- name of file to restore from (without its extension .pth.tar)
    '''
    # reload weights from restore_file if specified
    restore_epoch = 0
    if restore_file is not None:
        restore_path = os.path.join(params.model_dir, restore_file + '.pth.tar')
        logger.info('Restoring parameters from {}'.format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)
        restore_epoch = int(restore_file[-2:].replace('_',''))+1
    logger.info('Restoring epoch: {}'.format(restore_epoch))
    logger.info('Begin training and evaluation')
    
    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=25, verbose=True, delta=0.0001, folder=params.model_dir)
    
    if os.path.exists(os.path.join(params.model_dir, 'metrics_test_best_weights.json')):
        with open(os.path.join(params.model_dir, 'metrics_test_best_weights.json')) as json_file:
            best_test_ND = json.load(json_file)['ND']
            early_stopping.best_score = best_test_ND
    else:
        best_test_ND = float('inf')
        early_stopping.best_score = best_test_ND
    
    train_len = len(train_loader)
    ND_summary = np.zeros(params.num_epochs)
    loss_summary = np.zeros((train_len * params.num_epochs))
    
    for epoch in range(restore_epoch, params.num_epochs):
        logger.info('Epoch {}/{}'.format(epoch + 1, params.num_epochs))
        loss_summary[epoch * train_len:(epoch + 1) * train_len] = train(model, optimizer, loss_fn, train_loader,
                                                                        test_loader, params, epoch)
        test_metrics = evaluate(model, loss_fn, test_loader, params, epoch, sample=args.sampling)
#         if test_metrics['ND'] == float('nan'):
#             test_metrics['ND'] = 1000
#             print('NAN ')

#         elif test_metrics['ND'] == np.nan:
#             print('NAN ')
#             test_metrics['ND'] = 1000
        
        ND_summary[epoch] = test_metrics['ND']
        is_best = ND_summary[epoch] <= best_test_ND

        # Save weights
        utils.save_checkpoint({'epoch': epoch + 1,
                               'state_dict': model.state_dict(),
                               'optim_dict': optimizer.state_dict()},
                              epoch=epoch,
                              is_best=is_best,
                              checkpoint=params.model_dir)

        if is_best:
            logger.info('- Found new best ND')
            best_test_ND = ND_summary[epoch]
            best_json_path = os.path.join(params.model_dir, 'metrics_test_best_weights.json')
            utils.save_dict_to_json(test_metrics, best_json_path)

        logger.info('Current Best ND is: %.5f' % best_test_ND)

        utils.plot_all_epoch(ND_summary[:epoch + 1], args.dataset + '_ND', params.plot_dir)
        utils.plot_all_epoch(loss_summary[:(epoch + 1) * train_len], args.dataset + '_loss', params.plot_dir)

        last_json_path = os.path.join(params.model_dir, 'metrics_test_last_weights.json')
        utils.save_dict_to_json(test_metrics, last_json_path)
        
        
        # early_stopping needs the validation loss to check if it has decreased,
        # and if it has, it will make a checkpoint of the current model
        logger.info('ND : %.5f ' % test_metrics['ND'])
        early_stopping(test_metrics['ND'], model)
        
        if early_stopping.early_stop:
            logger.info('Early stopping')
            break
        
#     # load the last checkpoint with the best model
#     model.load_state_dict(torch.load('checkpoint.pt'))

    if args.save_best:
        f = open('./param_search.txt', 'w')
        f.write('-----------\n')
        list_of_params = args.search_params.split(',')
        print_params = ''
        for param in list_of_params:
            param_value = getattr(params, param)
            print_params += f'{param}: {param_value:.2f} '
        print_params = print_params[:-1]
        f.write(print_params + '\n')
        f.write('Best ND: ' + str(best_test_ND) + '\n')
        logger.info(print_params)
        logger.info(f'Best ND: {best_test_ND}')
        f.close()
        utils.plot_all_epoch(ND_summary, print_params + '_ND', location=params.plot_dir)
        utils.plot_all_epoch(loss_summary, print_params + '_loss', location=params.plot_dir)
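
Note: Examples #4 and #26 construct EarlyStopping(patience=..., verbose=True, delta=..., folder=...) and call it every epoch with a metric and the model, but the class itself never appears on this page. A minimal sketch consistent with that usage (lower metric is better, best model written to checkpoint.pt in folder); treat the details as assumptions rather than the projects' actual class:

import os
import torch

class EarlyStopping:
    """Stop training when the monitored metric has not improved for `patience` epochs."""

    def __init__(self, patience=7, verbose=False, delta=0.0, folder='.'):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.folder = folder
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, metric, model):
        # lower is better for ND / loss-like metrics
        if self.best_score is None or metric < self.best_score - self.delta:
            self.best_score = metric
            self.counter = 0
            # checkpoint the best model seen so far
            torch.save(model.state_dict(), os.path.join(self.folder, 'checkpoint.pt'))
            if self.verbose:
                print('Metric improved to {:.6f}, saving model'.format(metric))
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True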
Example #5
def train(train_loader, val_loader, class_weights, class_encoding):
    print("Training...")
    num_classes = len(class_encoding)
    model = ERFNet(num_classes)
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.Adam(model.parameters(),
                           lr=args.learning_rate,
                           weight_decay=args.weight_decay)
    # Learning rate decay scheduler
    lr_updater = lr_scheduler.StepLR(optimizer, args.lr_decay_epochs,
                                     args.lr_decay)

    # Evaluation metric
    if args.ignore_unlabeled:
        ignore_index = list(class_encoding).index('unlabeled')
    else:
        ignore_index = None

    metric = IoU(num_classes, ignore_index=ignore_index)

    if use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        model, optimizer, start_epoch, best_miou, val_miou, train_miou, val_loss, train_loss = utils.load_checkpoint(
            model, optimizer, args.save_dir, args.name, True)
        print(
            "Resuming from model: Start epoch = {0} | Best mean IoU = {1:.4f}".
            format(start_epoch, best_miou))
    else:
        start_epoch = 0
        best_miou = 0
        val_miou = []
        train_miou = []
        val_loss = []
        train_loss = []

    # Start Training
    train = Train(model, train_loader, optimizer, criterion, metric, use_cuda)
    val = Test(model, val_loader, criterion, metric, use_cuda)

    for epoch in range(start_epoch, args.epochs):
        print(">> [Epoch: {0:d}] Training".format(epoch))
        lr_updater.step()
        epoch_loss, (iou, miou) = train.run_epoch(args.print_step)
        print(
            ">> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".format(
                epoch, epoch_loss, miou))
        train_loss.append(epoch_loss)
        train_miou.append(miou)

        # perform a validation test
        if (epoch + 1) % 10 == 0 or epoch + 1 == args.epochs:
            print(">>>> [Epoch: {0:d}] Validation".format(epoch))
            loss, (iou, miou) = val.run_epoch(args.print_step)
            print(">>>> [Epoch: {0:d}] Avg. loss: {1:.4f} | Mean IoU: {2:.4f}".
                  format(epoch, loss, miou))
            val_loss.append(loss)
            val_miou.append(miou)
            # Print per class IoU on last epoch or if best iou
            if epoch + 1 == args.epochs or miou > best_miou:
                for key, class_iou in zip(class_encoding.keys(), iou):
                    print("{0}: {1:.4f}".format(key, class_iou))
            # Save the model if it's the best thus far
            if miou > best_miou:
                print("Best model thus far. Saving...")
                best_miou = miou
                utils.save_checkpoint(model, optimizer, epoch + 1, best_miou,
                                      val_miou, train_miou, val_loss,
                                      train_loss, args)

    return model, train_loss, train_miou, val_loss, val_miou
Example #6
def main():
    global args, best_prec_result

    utils.default_model_dir = args.dir
    start_time = time.time()

    Source_train_loader, Source_test_loader = dataset_selector(args.sd)
    Target_train_loader, Target_test_loader = dataset_selector(args.td)

    state_info = utils.model_optim_state_info()
    state_info.model_init()
    state_info.model_cuda_init()

    if cuda:
        # os.environ["CUDA_VISIBLE_DEVICES"] = '0'
        print("USE", torch.cuda.device_count(), "GPUs!")
        state_info.weight_cuda_init()
        cudnn.benchmark = True
    else:
        print("NO GPU")

    state_info.optimizer_init(lr=args.lr,
                              b1=args.b1,
                              b2=args.b2,
                              weight_decay=args.weight_decay)

    # adversarial_loss = torch.nn.BCELoss()
    criterion_GAN = torch.nn.MSELoss()
    criterion_cycle = torch.nn.L1Loss()
    criterion_identity = torch.nn.L1Loss()
    criterion = nn.CrossEntropyLoss().cuda()

    start_epoch = 0

    checkpoint = utils.load_checkpoint(utils.default_model_dir)
    if not checkpoint:
        state_info.learning_scheduler_init(args)
    else:
        start_epoch = checkpoint['epoch'] + 1
        best_prec_result = checkpoint['Best_Prec']
        state_info.load_state_dict(checkpoint)
        state_info.learning_scheduler_init(args, load_epoch=start_epoch)

    realA_sample_iter = iter(Source_train_loader)
    realB_sample_iter = iter(Target_train_loader)

    realA_sample = to_var(next(realA_sample_iter)[0], FloatTensor)
    realB_sample = to_var(next(realB_sample_iter)[0], FloatTensor)

    for epoch in range(args.epoch):

        train(state_info, Source_train_loader, Target_train_loader,
              criterion_GAN, criterion_cycle, criterion_identity, criterion,
              epoch)
        prec_result = test(state_info, Source_test_loader, Target_test_loader,
                           criterion, realA_sample, realB_sample, epoch)

        if prec_result > best_prec_result:
            best_prec_result = prec_result
            filename = 'checkpoint_best.pth.tar'
            utils.save_state_checkpoint(state_info, best_prec_result, filename,
                                        utils.default_model_dir, epoch)

        filename = 'latest.pth.tar'
        utils.save_state_checkpoint(state_info, best_prec_result, filename,
                                    utils.default_model_dir, epoch)
        state_info.learning_step()

    now = time.gmtime(time.time() - start_time)
    utils.print_log('{} hours {} mins {} secs for training'.format(
        now.tm_hour, now.tm_min, now.tm_sec))
    d_summary = utils.summary({d_loss: 'd_loss'})
    g_summary = utils.summary({g_loss: 'g_loss'})

    # sample
    f_sample = generator(z, size=output_image_size, training=False)

# session
sess = utils.session()
# iteration counter
it_cnt, update_cnt = utils.counter()
# saver
saver = tf.train.Saver(max_to_keep=5)
# summary writer
summary_writer = tf.summary.FileWriter(log_path, sess.graph)

if not utils.load_checkpoint(ckpt_path, sess):
    sess.run(tf.global_variables_initializer())

try:
    z_ipt_sample = np.random.normal(
        size=[output_num_sqrt * output_num_sqrt, z_dim])

    batch_epoch = len(data_pool) // (batch_size)
    max_it = max_epoch * batch_epoch
    for it in range(sess.run(it_cnt), max_it):
        sess.run(update_cnt)

        # which epoch
        epoch = it // batch_epoch
        # it_epoch = it % batch_epoch + 1
        it_epoch = it % batch_epoch
crop_size = args.crop_size


""" run """
with tf.Session() as sess:
    a_real = tf.placeholder(tf.float32, shape=[None, crop_size, crop_size, 3])
    b_real = tf.placeholder(tf.float32, shape=[None, crop_size, crop_size, 3])

    a2b = models.generator(a_real, 'a2b')
    b2a = models.generator(b_real, 'b2a')
    b2a2b = models.generator(b2a, 'a2b', reuse=True)
    a2b2a = models.generator(a2b, 'b2a', reuse=True)

    # restore
    saver = tf.train.Saver()
    ckpt_path = utils.load_checkpoint('./checkpoints/' + dataset, sess, saver)
    if ckpt_path is None:
        raise Exception('No checkpoint!')
    else:
        print('Copy variables from %s' % ckpt_path)

    # test
    a_list = glob('./datasets/' + dataset + '/testA/*.jpg')
    b_list = glob('./datasets/' + dataset + '/testB/*.jpg')

    a_save_dir = './test_predictions/' + dataset + '/testA/'
    b_save_dir = './test_predictions/' + dataset + '/testB/'
    utils.mkdir([a_save_dir, b_save_dir])
    for i in range(len(a_list)):
        a_real_ipt = im.imresize(im.imread(a_list[i]), [crop_size, crop_size])
        a_real_ipt.shape = 1, crop_size, crop_size, 3
Example #9
            print( "y_train_fold.shape={}, y_valid_fold.shape={}".format(y_train_fold.shape, y_valid_fold.shape ) )

        #--------------------
        # Model definition
        #--------------------
        # resnet50
        resnet = ResNet50(
            image_height = args.image_height,
            image_width = args.image_width,
            n_channles = 3,
            pretrained = args.pretrained,
            train_only_fc = args.train_only_fc,
        ).finetuned_resnet50
        resnet.compile( loss = 'categorical_crossentropy', optimizer = optimizers.Adam( lr = 0.0001, beta_1 = 0.5, beta_2 = 0.999 ), metrics = ['accuracy'] )
        if not args.load_checkpoints_path == '' and os.path.exists(args.load_checkpoints_path):
            load_checkpoint(resnet, args.load_checkpoints_path )

        resnet_classifier = ImageClassifierKeras(
            resnet,
            n_classes = 2,
            debug = args.debug,
        )

        # Other machine learning models
        knn_classifier = ImageClassifierSklearn( 
            KNeighborsClassifier( n_neighbors = 2, p = 2, metric = 'minkowski', n_jobs = -1 ),
            debug = args.debug,
        )

        svm_classifier = ImageClassifierSklearn(
            SVC( kernel = 'rbf',gamma = 10.0, C = 0.1, probability = True, random_state = args.seed ),
Example #10
torch.manual_seed(hps.train.seed)
device = torch.device("cuda" if use_cuda else "cpu")

kwargs = {'num_workers': 8, 'pin_memory': True} if use_cuda else {}
train_dataset = LJspeechDataset(hps.data.data_path, True, 0.1)
test_dataset = LJspeechDataset(hps.data.data_path, False, 0.1)
train_loader = DataLoader(train_dataset, batch_size=hps.train.batch_size, shuffle=True, collate_fn=collate_fn,
                          **kwargs)
test_loader = DataLoader(test_dataset, batch_size=hps.train.batch_size, collate_fn=collate_fn,
                          **kwargs)
iteration = 1
generator = models.Generator(hps.data.n_channels).to(device)
discriminator = models.MultiScaleDiscriminator().to(device)
optimizer_g = optim.Adam(generator.parameters(), lr=hps.train.learning_rate)
optimizer_d = optim.Adam(discriminator.parameters(), lr=hps.train.learning_rate)
generator, optimizer_g, _, iteration = utils.load_checkpoint("logs/exp1/G_488.pth", generator, optimizer_g)
discriminator, optimizer_d, _, _ = utils.load_checkpoint("logs/exp1/D_488.pth", discriminator, optimizer_d)

def feature_matching_loss(rs_t, rs_f):
  l_tot = 0
  for d_t, d_f in zip(rs_t, rs_f):
    l_tot += torch.mean(torch.abs(d_t - d_f))
  return l_tot


def train(epoch):
  generator.train()
  discriminator.train()
  train_loss = 0
  for batch_idx, (x, c, _) in enumerate(train_loader):
    x, c = x.to(device), c.to(device)
Example #11
def train_and_evaluate(model,
                       train_dataloader,
                       val_dataloader,
                       optimizer,
                       loss_fn,
                       metrics,
                       params,
                       model_dir,
                       restore_file=None,
                       cuda_id=0):

    # reload weights from restore_file if specified
    if restore_file is not None:
        restore_path = os.path.join(model_dir,
                                    restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)

    best_val_acc = 0.0
    # train() takes two extra arguments here: logger and epoch
    logger = Logger('./logs')
    scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=1)

    for epoch in range(params.num_epochs):
        # Run one epoch
        #         scheduler.step()
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

        # compute number of batches in one epoch (one full pass over the training set)
        train(model, optimizer, loss_fn, train_dataloader, metrics, params,
              logger, epoch, cuda_id)

        # Evaluate for one epoch on validation set
        val_metrics = evaluate(model, loss_fn, val_dataloader, metrics, params,
                               cuda_id)

        val_acc = val_metrics['PSNR']
        is_best = val_acc >= best_val_acc

        # Save weights
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optim_dict': optimizer.state_dict()
            },
            is_best=is_best,
            checkpoint=model_dir)

        # If best_eval, best_save_path
        if is_best:
            logging.info("- Found new best accuracy")
            best_val_acc = val_acc

            # Save best val metrics in a json file in the model directory
            best_json_path = os.path.join(model_dir,
                                          "metrics_val_best_weights.json")
            utils.save_dict_to_json(val_metrics, best_json_path)

        # Save latest val metrics in a json file in the model directory
        last_json_path = os.path.join(model_dir,
                                      "metrics_val_last_weights.json")
        utils.save_dict_to_json(val_metrics, last_json_path)

    plt.plot(global_loss)
    plt.savefig("final loss.jpg")
Example #12
if __name__ == '__main__':
  print(colored("\n==>", 'blue'), emojify("Parsing arguments ... :hammer:\n"))
  assert args.reso % 32 == 0, "Resolution must be an integer multiple of 32"
  for arg in vars(args):
    print(arg, ':', getattr(args, arg))

  print(colored("\n==>", 'blue'), emojify("Prepare data ... :coffee:\n"))
  img_datasets, dataloader = prepare_demo_dataset(config.demo['images_dir'], args.reso)
  print("Number of demo images:", len(img_datasets))

  print(colored("\n==>", 'blue'), emojify("Loading network ... :hourglass:\n"))
  yolo = YOLOv3(cfg, args.reso).cuda()
  start_epoch, start_iteration = args.checkpoint.split('.')
  start_epoch, start_iteration, state_dict = load_checkpoint(
    opj(config.CKPT_ROOT, args.dataset),
    int(start_epoch),
    int(start_iteration)
  )
  yolo.load_state_dict(state_dict)
  print("Model starts training from epoch %d iteration %d" % (start_epoch, start_iteration))

  print(colored("\n==>", 'blue'), emojify("Evaluation ...\n"))
  yolo.eval()
  for batch_idx, (img_path, inputs) in enumerate(tqdm(dataloader, ncols=80)):
    inputs = inputs.cuda()
    detections = yolo(inputs)
  
    # take idx 0
    detections = detections[detections[:, 0] == 0]
    path = img_path[0]
max_size = 15

# Tensorboard to get nice plots etc
writer = SummaryWriter(f"runs/loss_plot2")
step = 0

encoder_net = Encoder(hidden_size, num_layers).to(device)
decoder_net = Decoder(hidden_size, num_layers).to(device)

model = Seq2Seq(encoder_net, decoder_net).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

criterion = nn.CrossEntropyLoss()

if load_model:
    load_checkpoint(torch.load("my_checkpoint.pth.tar"), model, optimizer)

# following is for testing the network, uncomment this if you want
# to try out a few arrays interactively
# sort_array(encoder_net, decoder_net, device)

dataset = SortArray(batch_size, min_int, max_int, min_size, max_size)
train_loader = DataLoader(dataset, batch_size=1, shuffle=False)

for epoch in range(num_epochs):
    print(f"[Epoch {epoch} / {num_epochs}]")

    if save_model:
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
# -

# Set random seed to the specified value
np.random.seed(utils.random_seed)
torch.manual_seed(utils.random_seed)

# ## Loading data and model

# Read the data (already processed, just like the model trained on)
ALS_df = pd.read_csv(f'{data_path}cleaned/FCUL_ALS_cleaned.csv')

# Drop the unnamed index column
ALS_df.drop(columns=['Unnamed: 0', 'niv'], inplace=True)

# Load the model with the best validation performance
model = utils.load_checkpoint(f'{model_path}checkpoint_07_06_2019_23_14.pth')

# ## Getting train and test sets, in tensor format

# Dictionary containing the sequence length (number of temporal events) of each sequence (patient)
seq_len_df = ALS_df.groupby('subject_id').ts.count().to_frame().sort_values(
    by='ts', ascending=False)
seq_len_dict = dict([
    (idx, val[0])
    for idx, val in list(zip(seq_len_df.index, seq_len_df.values))
])

# +
n_patients = ALS_df.subject_id.nunique()  # Total number of patients
n_inputs = len(ALS_df.columns)  # Number of input features
padding_value = 999999  # Value to be used in the padding
Example #15
                           embedding_dim,
                           hidden_size,
                           unit='lstm',
                           layer_num=2)
generative.init_hidden(batch_size, use_gpu)
if use_gpu:
    generative.cuda()

criterion = nn.NLLLoss()
optimizer = optim.Adam(generative.parameters(), lr=5e-3)
num_epochs = 60
model_filename = 'people-lstm-layer-2.model'
names = people_names()
generative.train()
if os.path.exists(model_filename):
    generative, optimizer, starting_epoch, _ = load_checkpoint(
        model_filename, generative, optimizer)
else:
    starting_epoch = 1
print('starting from epoch %d' % starting_epoch)
try:
    for epoch in range(starting_epoch, num_epochs + 1):
        iterators = tqdm(yield_dataset(names, batch_size=batch_size))
        for data in iterators:
            input_words = word2Variable(data)
            target_words = makeTargetVariable(data)
            generative.init_hidden(batch_size, use_gpu)
            if use_gpu:
                input_words = input_words.cuda()
                target_words = target_words.cuda()
            generative.zero_grad()
            loss = 0
    logging.info("Loading the dataset...")

    # fetch dataloaders
    # train_dl = data_loader.fetch_dataloader('train', params)
    # dev_dl = data_loader.fetch_dataloader('dev', params)
    dataloader = data_loader.fetch_dataloader(args.dataset, params)

    logging.info("- done.")

    # Define the model graph
    model = resnet.ResNet18().cuda() if params.cuda else resnet.ResNet18()

    # fetch loss function and metrics
    metrics = resnet.metrics
    
    logging.info("Starting analysis...")

    # Reload weights from the saved file
    utils.load_checkpoint(os.path.join(args.model_dir, args.restore_file + '.pth.tar'), model)

    # Evaluate and analyze
    softmax_scores, predict_correct, confusion_matrix = model_analysis(model, dataloader, params,
                                                                       args.temperature)

    results = {'softmax_scores': softmax_scores, 'predict_correct': predict_correct,
               'confusion_matrix': confusion_matrix}

    for k, v in results.items():
        filename = args.dataset + '_temp' + str(args.temperature) + '_' + k + '.txt'
        save_path = os.path.join(args.model_dir, filename)
        np.savetxt(save_path, v)
Example #17
def train_and_evaluate(netG,
                       netD,
                       train_dataloader,
                       val_dataloader,
                       optimG,
                       optimD,
                       loss_fn,
                       metrics,
                       params,
                       model_dir,
                       restore_file=None,
                       cuda_id=0):
    # reload weights from restore_file if specified
    if restore_file is not None:
        restore_path_g = os.path.join(args.model_dir, 'best_g' + '.pth.tar')
        restore_path_d = os.path.join(args.model_dir, 'best_d' + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path_g))
        utils.load_checkpoint(restore_path_g, netG, optimG)
        utils.load_checkpoint(restore_path_d, netD, optimD)

    best_val_acc = 0.0
    # train add logger,epoch two parameters
    #     logger = Logger('./logs')

    for epoch in range(params.num_epochs):
        # Run one epoch
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

        # compute number of batches in one epoch (one full pass over the training set)
        train(netG, netD, optimG, optimD, loss_fn, train_dataloader, metrics,
              params, cuda_id)

        # Evaluate for one epoch on validation set
        val_metrics = evaluate(netG, netD, loss_fn, val_dataloader, metrics,
                               params, cuda_id)
        #print ('after val --------')

        val_acc = val_metrics['PSNR']
        is_best = val_acc >= best_val_acc

        #Save weights
        # save G
        flag = 'G'
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': netG.state_dict(),
                'optim_dict': optimG.state_dict()
            },
            is_best=is_best,
            checkpoint=model_dir,
            flag=flag)
        flag = 'D'
        # save D
        utils.save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': netD.state_dict(),
                'optim_dict': optimD.state_dict()
            },
            is_best=is_best,
            checkpoint=model_dir,
            flag=flag)

        # If best_eval, best_save_path
        if is_best:
            logging.info("- Found new best accuracy")
            best_val_acc = val_acc

            # Save best val metrics in a json file in the model directory
            best_json_path = os.path.join(model_dir,
                                          "metrics_val_best_weights.json")
            utils.save_dict_to_json(val_metrics, best_json_path)

        # Save latest val metrics in a json file in the model directory
        last_json_path = os.path.join(model_dir,
                                      "metrics_val_last_weights.json")
        utils.save_dict_to_json(val_metrics, last_json_path)

        if epoch % 100 == 0 and epoch > 99:
            plt.plot(global_loss_g)
            plt.savefig(str(epoch) + " epoch_g.jpg")
            plt.plot(global_loss_d)
            plt.savefig(str(epoch) + " epoch_d.jpg")

    plt.plot(global_loss_g)
    plt.savefig("final loss_g.jpg")
    plt.plot(global_loss_d)
    plt.savefig("final loss_d.jpg")
Example #18
def main():
    start_time = time.time()

    utils.init_out_dir()
    last_epoch = utils.get_last_checkpoint_step()
    if last_epoch >= args.epoch:
        exit()
    if last_epoch >= 0:
        my_log('\nCheckpoint found: {}\n'.format(last_epoch))
    else:
        utils.clear_log()
    utils.print_args()

    flow = build_mera()
    flow.train(True)
    my_log('nparams in each RG layer: {}'.format(
        [utils.get_nparams(layer) for layer in flow.layers]))
    my_log('Total nparams: {}'.format(utils.get_nparams(flow)))

    # Use multiple GPUs
    if args.cuda and torch.cuda.device_count() > 1:
        flow = utils.data_parallel_wrap(flow)

    params = [x for x in flow.parameters() if x.requires_grad]
    optimizer = torch.optim.AdamW(params,
                                  lr=args.lr,
                                  weight_decay=args.weight_decay)

    if last_epoch >= 0:
        utils.load_checkpoint(last_epoch, flow, optimizer)

    train_split, val_split, data_info = utils.load_dataset()
    train_loader = torch.utils.data.DataLoader(train_split,
                                               args.batch_size,
                                               shuffle=True,
                                               num_workers=1,
                                               pin_memory=True)

    init_time = time.time() - start_time
    my_log('init_time = {:.3f}'.format(init_time))

    my_log('Training...')
    start_time = time.time()
    for epoch_idx in range(last_epoch + 1, args.epoch + 1):
        for batch_idx, (x, _) in enumerate(train_loader):
            optimizer.zero_grad()

            x = x.to(args.device)
            x, ldj_logit = utils.logit_transform(x)
            log_prob = flow.log_prob(x)
            loss = -(log_prob + ldj_logit) / (args.nchannels * args.L**2)
            loss_mean = loss.mean()
            loss_std = loss.std()

            utils.check_nan(loss_mean)

            loss_mean.backward()
            if args.clip_grad:
                clip_grad_norm_(params, args.clip_grad)
            optimizer.step()

            if args.print_step and batch_idx % args.print_step == 0:
                bit_per_dim = (loss_mean.item() + log(256)) / log(2)
                my_log(
                    'epoch {} batch {} bpp {:.8g} loss {:.8g} +- {:.8g} time {:.3f}'
                    .format(
                        epoch_idx,
                        batch_idx,
                        bit_per_dim,
                        loss_mean.item(),
                        loss_std.item(),
                        time.time() - start_time,
                    ))

        if (args.out_filename and args.save_epoch
                and epoch_idx % args.save_epoch == 0):
            state = {
                'flow': flow.state_dict(),
                'optimizer': optimizer.state_dict(),
            }
            torch.save(state,
                       '{}_save/{}.state'.format(args.out_filename, epoch_idx))

            if epoch_idx > 0 and (epoch_idx - 1) % args.keep_epoch != 0:
                os.remove('{}_save/{}.state'.format(args.out_filename,
                                                    epoch_idx - 1))

        if (args.plot_filename and args.plot_epoch
                and epoch_idx % args.plot_epoch == 0):
            with torch.no_grad():
                do_plot(flow, epoch_idx)
Example #19
def train_and_evaluate(model, train_dataloader, val_dataloader, optimizer, loss_fn, metrics, params, model_dir,
                       restore_file=None):
    """Train the model and evaluate every epoch.
    Args:
        model: (torch.nn.Module) the neural network
        train_dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches training data
        val_dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches validation data
        optimizer: (torch.optim) optimizer for parameters of model
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        metrics: (dict) a dictionary of functions that compute a metric using the output and labels of each batch
        params: (Params) hyperparameters
        model_dir: (string) directory containing config, weights and log
        restore_file: (string) optional- name of file to restore from (without its extension .pth.tar)
    """
    # reload weights from restore_file if specified
    if restore_file is not None:
        restore_path = os.path.join(
            args.model_dir, args.restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)

    best_val_q8acc = 0.0
    best_val_q3acc = 0.0

    for epoch in range(params.num_epochs):
        # Run one epoch
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

        # compute number of batches in one epoch (one full pass over the training set)
        train(model, optimizer, loss_fn, train_dataloader, metrics, params)

        # Evaluate for one epoch on validation set
        val_metrics = evaluate(model, loss_fn, val_dataloader, metrics, params)

        val_q8acc = val_metrics['val_q8accuracy']
        val_q3acc = val_metrics['val_q3accuracy']
        is_q8best = val_q8acc >= best_val_q8acc
        is_q3best = val_q3acc >= best_val_q3acc

        # Save weights
        utils.save_checkpoint({'epoch': epoch + 1,
                               'state_dict': model.state_dict(),
                               'optim_dict': optimizer.state_dict()},
                               is_q3best=is_q3best,
                               is_q8best=is_q8best,
                               checkpoint=model_dir)

        # If best_eval, best_save_path
        if is_q8best:
            logging.info("- Found new best q8 accuracy")
            best_val_q8acc = val_q8acc

            # Save best val metrics in a json file in the model directory
            best_json_path = os.path.join(
                model_dir, "metrics_val_q8best_weights.json")
            utils.save_dict_to_json(val_metrics, best_json_path)

        if is_q3best:
            logging.info("- Found new best q3 accuracy")
            best_val_q3acc = val_q3acc

            # Save best val metrics in a json file in the model directory
            best_json_path = os.path.join(
                model_dir, "metrics_val_q3best_weights.json")
            utils.save_dict_to_json(val_metrics, best_json_path)


        # Save latest val metrics in a json file in the model directory
        last_json_path = os.path.join(
            model_dir, "metrics_val_last_weights.json")
        utils.save_dict_to_json(val_metrics, last_json_path)
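
Note: most of the train_and_evaluate variants above take a params object documented only as '(Params) hyperparameters'. A minimal sketch of such a class, assuming it simply wraps a json config file; Example #22 constructs it with no arguments, so the real classes likely differ in detail, and the attribute names below are illustrative:

import json

class Params:
    """Load hyperparameters from a json file and expose them as attributes."""

    def __init__(self, json_path):
        with open(json_path) as f:
            self.__dict__.update(json.load(f))

    def save(self, json_path):
        with open(json_path, 'w') as f:
            json.dump(self.__dict__, f, indent=4)

# usage sketch: params = Params('experiments/base_model/params.json')
#               params.num_epochs, params.model_dir, params.learning_rate, ...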
Example #20
    def __init__(self, n_obs, n_control, n_latent, n_enc, chkpoint_file=None):

        self.learning_rate = tf.placeholder(tf.float32)
        self.annealing_rate = tf.placeholder(tf.float32)

        # Dimensions
        self.n_output = n_obs
        self.n_obs = n_obs
        self.n_control = n_control
        self.n_latent = n_latent
        self.n_enc = n_enc

        # The placeholder from the input
        self.x = tf.placeholder(tf.float32, [None, None, self.n_obs], name="X")
        self.u = tf.placeholder(tf.float32, [None, None, self.n_control],
                                name="U")

        # Initialize p(z0), p(x|z), q(z'|enc, u, z) and p(z'|z) as well as the mlp that
        # generates a low dimensional encoding of x, called enc
        self._init_generative_dist()
        self._init_start_dist()
        self._init_encoding_mlp()
        self.transition = BaselineTransitionNoKL(self.n_latent, self.n_enc,
                                                 self.n_control)

        # Get the encoded representation of the observations (this makes sense when observations are highdimensional images for example)
        enc = self.get_enc_rep(self.x)

        # Get the latent start state
        q0 = self.get_start_dist(self.x[0])
        z0 = q0.sample()
        log_q0 = q0.log_prob(z0)
        p0 = MultivariateNormalDiag(tf.zeros(tf.shape(z0)),
                                    tf.ones(tf.shape(z0)))
        log_p0 = p0.log_prob(z0)

        # Trajectory rollout in latent space + calculation of KL(q(z'|enc, u, z) || p(z'|z))
        z, log_q, log_p = tf.scan(self.transition.one_step_IAF,
                                  (self.u[:-1], enc[1:]), (z0, log_q0, log_p0))
        self.z = tf.concat([[z0], z], 0)

        # Get the generative distribution p(x|z) + calculation of the reconstruction error
        # TODO: Including x[0], revert if doesn't work
        # px = self.get_generative_dist(z)
        # rec_loss = -px.log_prob(self.x[1:])

        # TODO: Including x[0], Remove if doesn't work
        px = self.get_generative_dist(self.z)
        rec_loss = -px.log_prob(self.x)

        self.px_mean = px.mean()

        # Generating trajectories given only an initial observation
        gen_z = tf.scan(self.transition.gen_one_step, self.u[:-1], z0)
        self.gen_z = tf.concat([[z0], gen_z], 0)
        gen_px = self.get_generative_dist(self.gen_z)
        self.gen_x_mean = gen_px.mean()

        # TODO: Including x[0], Remove if doesn't work
        log_p = tf.concat([[log_p0], log_p], 0)
        log_q = tf.concat([[log_q0], log_q], 0)

        # Create the losses
        # self.rec_loss = rec_loss
        self.log_p = log_p * self.annealing_rate
        self.log_q = log_q * self.annealing_rate
        self.rec_loss = tf.reduce_mean(rec_loss)
        self.kl_loss = tf.reduce_mean(self.log_q - self.log_p)

        # self.total_loss = tf.reduce_mean(rec_loss + self.log_q - self.log_p)
        self.total_loss = self.kl_loss + self.rec_loss

        # Use the Adam optimizer with clipped gradients
        self.optimizer = tf.train.AdamOptimizer(
            learning_rate=self.learning_rate)
        grads_and_vars = self.optimizer.compute_gradients(self.total_loss)
        capped_grads_and_vars = [(tf.clip_by_value(grad, -1, 1),
                                  var) if grad is not None else (grad, var)
                                 for (grad, var) in grads_and_vars]
        self.optimizer = self.optimizer.apply_gradients(capped_grads_and_vars)

        # Save weights
        self.saver = tf.train.Saver(tf.global_variables(), max_to_keep=None)

        # Launch the session
        self.sess = tf.InteractiveSession()
        self.sess.run(tf.global_variables_initializer())
        if chkpoint_file:
            utils.load_checkpoint(self.sess, chkpoint_file)
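
Note: Example #20 defines the graph and the clipped Adam update op but never shows a training step. Based only on the placeholders and ops it creates (self.x, self.u, self.learning_rate, self.annealing_rate, self.optimizer, self.total_loss), one step could look roughly like the sketch below; x_batch, u_batch and anneal are hypothetical inputs, not names from the original project:

# hypothetical single training step for the model in Example #20
feed = {
    model.x: x_batch,              # observations, time-major: [T, batch, n_obs]
    model.u: u_batch,              # controls, time-major: [T, batch, n_control]
    model.learning_rate: 1e-3,
    model.annealing_rate: anneal,  # KL annealing factor, e.g. ramped from 0 to 1
}
_, total, rec, kl = model.sess.run(
    [model.optimizer, model.total_loss, model.rec_loss, model.kl_loss],
    feed_dict=feed)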
Example #21
""" train """
''' init '''
# session
sess = utils.session()
# iteration counter
it_cnt, update_cnt = utils.counter()
# saver
saver = tf.train.Saver(max_to_keep=5)
# summary writer
summary_writer = tf.summary.FileWriter('./summaries/celeba_wgan', sess.graph)

''' initialization '''
ckpt_dir = './checkpoints/celeba_wgan'
utils.mkdir(ckpt_dir + '/')
if not utils.load_checkpoint(ckpt_dir, sess):
    sess.run(tf.global_variables_initializer())

''' train '''
try:
    z_ipt_sample = np.random.normal(size=[100, z_dim])

    batch_epoch = len(data_pool) // (batch_size * n_critic)
    max_it = epoch * batch_epoch
    for it in range(sess.run(it_cnt), max_it):
        sess.run(update_cnt)

        # which epoch
        epoch = it // batch_epoch
        it_epoch = it % batch_epoch + 1
Example #22
    params = utils.Params()
    # Set the random seed for reproducible experiments
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    params.seed = args.seed

    # Set the logger
    utils.set_logger()

    # get dataloader
    dataloader = NERDataLoader(params)

    # Define the model
    logging.info('Loading the model...')
    bert_config = BertConfig.from_json_file(
        os.path.join(params.bert_model_dir, 'bert_config.json'))
    model = BertForTokenClassification(bert_config, params=params)
    model.to(params.device)
    # Reload weights from the saved file
    utils.load_checkpoint(
        os.path.join(params.model_dir, args.restore_file + '.pth.tar'), model)
    logging.info('- done.')

    logging.info("Loading the dataset...")
    loader = dataloader.get_dataloader(data_sign=mode)

    logging.info("Starting prediction...")
    # Create the input data pipeline
    predict(model, loader, params, mode)
    logging.info('-done')
Example #23
        loaders, w_class, class_encoding = load_dataset(dataset)
        train_loader, val_loader, test_loader = loaders

        if args.mode.lower() in {'train'}:
            model, tl, tmiou, vl, vmiou = train(train_loader, val_loader,
                                                w_class, class_encoding)
            plt.plot(tl, label="train loss")
            plt.plot(tmiou, label="train miou")
            plt.plot(vl, label="val loss")
            plt.plot(vmiou, label="val miou")
            plt.legend()
            plt.xlabel("Epoch")
            plt.ylabel("loss/accuracy")
            plt.grid(True)
            plt.xticks()
            plt.savefig('./plots/train.png')
        elif args.mode.lower() == 'test':
            num_classes = len(class_encoding)
            #model = ENet(num_classes)
            model = ERFNet(num_classes)
            if use_cuda:
                model = model.cuda()
            optimizer = optim.Adam(model.parameters())
            model = utils.load_checkpoint(model, optimizer, args.save_dir,
                                          args.name)[0]
            test(model, test_loader, w_class, class_encoding)
        else:
            raise RuntimeError(
                "\"{0}\" is not a valid choice for execution mode.".format(
                    args.mode))
def run(rank, n_gpus, hps):
    global global_step
    if rank == 0:
        writer = SummaryWriter(log_dir='./')
        writer_eval = SummaryWriter(log_dir='./eval')

    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=n_gpus,
                            rank=rank)
    torch.manual_seed(hps.train.seed)
    torch.cuda.set_device(rank)

    train_dataset = ImageTextLoader(hps.data.training_file_path, hps.data)
    train_sampler = DistributedBucketSampler(train_dataset,
                                             hps.train.num_tokens,
                                             num_replicas=n_gpus,
                                             rank=rank,
                                             shuffle=True)
    collate_fn = ImageTextCollate()
    train_loader = DataLoader(train_dataset,
                              num_workers=4,
                              shuffle=False,
                              pin_memory=True,
                              collate_fn=collate_fn,
                              batch_sampler=train_sampler)
    if rank == 0:
        eval_dataset = ImageTextLoader(hps.data.validation_file_path, hps.data)
        eval_sampler = DistributedBucketSampler(eval_dataset,
                                                hps.train.num_tokens,
                                                num_replicas=1,
                                                rank=rank,
                                                shuffle=True)
        eval_loader = DataLoader(eval_dataset,
                                 num_workers=0,
                                 shuffle=False,
                                 pin_memory=False,
                                 collate_fn=collate_fn,
                                 batch_sampler=eval_sampler)

    model = TableRecognizer(len(train_dataset.vocab),
                            3 * (hps.data.patch_length**2),
                            **hps.model).cuda(rank)
    optim = torch.optim.Adam(model.parameters(),
                             hps.train.learning_rate,
                             betas=hps.train.betas,
                             eps=hps.train.eps)
    model = DDP(model, device_ids=[rank])

    try:
        _, _, _, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path('./', "model_*.pth"), model, optim)
        global_step = (epoch_str - 1) * len(train_loader)
    except Exception:
        epoch_str = 1
        global_step = 0

    scaler = GradScaler(enabled=hps.train.fp16_run)

    for epoch in range(epoch_str, hps.train.epochs + 1):
        if rank == 0:
            train_and_evaluate(rank, epoch, hps, model, optim, scaler,
                               [train_loader, eval_loader],
                               [writer, writer_eval])
        else:
            train_and_evaluate(rank, epoch, hps, model, optim, scaler,
                               [train_loader, None], None)
def main(model_dir, model, dataset, batch_size=128, epochs=[150,250,350]):
    utils.default_model_dir = model_dir
    utils.c = None
    utils.str_w = ''
    # model = model
    lr = 0.1
    start_time = time.time()

    if dataset == 'cifar10':
        if batch_size == 128:
            train_loader, test_loader = utils.cifar10_loader()
        elif batch_size == 64:
            train_loader, test_loader = utils.cifar10_loader_64()
    elif dataset == 'cifar100':
        if batch_size == 128:
            train_loader, test_loader = utils.cifar100_loader()
        elif batch_size == 64:
            train_loader, test_loader = utils.cifar100_loader_64()
    

    if torch.cuda.is_available():
        # os.environ["CUDA_VISIBLE_DEVICES"] = '0'
        print("USE", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model).cuda()
        cudnn.benchmark = True

    else:
        print("NO GPU -_-;")

    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4, nesterov=True)
    criterion = nn.CrossEntropyLoss().cuda()

    start_epoch = 0
    checkpoint = utils.load_checkpoint(model_dir)
    
    if not checkpoint:
        pass
    else:
        start_epoch = checkpoint['epoch'] + 1
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])

    utils.init_learning(model.module)

    for epoch in range(start_epoch, epochs[2]):
        if epoch < epochs[0]:
            learning_rate = lr
        elif epoch < epochs[1]:
            learning_rate = lr * 0.1
        else:
            learning_rate = lr * 0.01
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate

        train(model, optimizer, criterion, train_loader, epoch, True)
        test(model, criterion, test_loader, epoch, True)

        utils.switching_learning(model.module)
        print('switching_learning to Gate')
        
        train(model, optimizer, criterion, train_loader, epoch, False)
        test(model, criterion, test_loader, epoch, False)        

        utils.switching_learning(model.module)
        print('switching_learning to Gate')

        if epoch % 5 == 0:
            model_filename = 'checkpoint_%03d.pth.tar' % epoch
            utils.save_checkpoint({
                'epoch': epoch,
                'model': model,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, model_filename, model_dir)

    now = time.gmtime(time.time() - start_time)
    weight_extract(model, optimizer, criterion, train_loader, epoch)
    utils.conv_weight_L1_printing(model.module)
    
    print('{} hours {} mins {} secs for training'.format(now.tm_hour, now.tm_min, now.tm_sec))
Example #26
def train_and_evaluate2(model: nn.Module, train_loader: DataLoader,
                        test_loader: DataLoader, optimizer: optim,
                        params: utils.Params, loss_fn: None,
                        restore_file: None, args: None, idx: None) -> None:
    '''Train the model and evaluate every epoch.
    Args:
        model: (torch.nn.Module) the Deep AR model
        train_loader: load train data and labels
        test_loader: load test data and labels
        optimizer: (torch.optim) optimizer for parameters of model
        loss_fn: a function that takes outputs and labels per timestep, and then computes the loss for the batch
        params: (Params) hyperparameters
        restore_file: (string) optional- name of file to restore from (without its extension .pth.tar)
    '''
    # reload weights from restore_file if specified
    if restore_file is not None:
        restore_path = os.path.join(params.model_dir,
                                    restore_file + '.pth.tar')
        logger.info('Restoring parameters from {}'.format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)
    logger.info('begin training and evaluation')
    best_test_ND = float('inf')

    # File to save first results
    out_file = os.path.join(os.path.join('experiments', args.model_name),
                            'train_results.csv')
    if not os.path.isfile(out_file):
        of_connection = open(out_file, 'w')
        writer = csv.writer(of_connection)
        # Write the headers to the file
        writer.writerow(['iteration', 'epoch', 'test_metric', 'train_loss'])
        of_connection.close()

    train_len = len(train_loader)
    ND_summary = np.zeros(params.num_epochs)
    loss_summary = np.zeros((train_len * params.num_epochs))

    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=5,
                                   verbose=True,
                                   delta=0.0001,
                                   folder=params.model_dir)

    for epoch in range(params.num_epochs):
        logger.info('Epoch {}/{}'.format(epoch + 1, params.num_epochs))
        loss_summary[epoch * train_len:(epoch + 1) * train_len] = train(
            model, optimizer, loss_fn, train_loader, test_loader, params,
            args.sampling, epoch)
        test_metrics = evaluate(model,
                                loss_fn,
                                test_loader,
                                params,
                                epoch,
                                sample=args.sampling)
        # NaN never compares equal to anything (including itself),
        # so use np.isnan to detect a failed evaluation
        if np.isnan(test_metrics['rou50']):
            test_metrics['rou50'] = 100

        ND_summary[epoch] = test_metrics['rou50']
        is_best = ND_summary[epoch] <= best_test_ND

        # Save weights
        utils.save_checkpoint(
            {
                'epoch': 0,  #epoch + 1,
                'state_dict': model.state_dict(),
                'optim_dict': optimizer.state_dict()
            },
            epoch=0,  # to prevent extra model savings
            is_best=is_best,
            checkpoint=params.model_dir)

        if is_best:
            logger.info('- Found new best ND')
            best_test_ND = ND_summary[epoch]
            best_json_path = os.path.join(params.model_dir,
                                          'metrics_test_best_weights.json')
            utils.save_dict_to_json(test_metrics, best_json_path)

        logger.info('Current Best loss is: %.5f' % best_test_ND)

        #if args.plot_figure:
        #    utils.plot_all_epoch(ND_summary[:epoch + 1], args.dataset + '_ND', params.plot_dir)
        #    utils.plot_all_epoch(loss_summary[:(epoch + 1) * train_len], args.dataset + '_loss', params.plot_dir)

        last_json_path = os.path.join(params.model_dir,
                                      'metrics_test_last_weights.json')
        utils.save_dict_to_json(test_metrics, last_json_path)
        # Append this epoch's results to the csv file ('a' means append)
        of_connection = open(out_file, 'a')
        writer = csv.writer(of_connection)
        # `idx` is assumed to be the search iteration defined by the enclosing script
        writer.writerow([idx, epoch + 1, test_metrics['rou50'],
                         loss_summary[epoch * train_len:(epoch + 1) * train_len].mean()])
        of_connection.close()
        logger.info('Loss_summary: %s' %
                    loss_summary[epoch * train_len:(epoch + 1) * train_len])

        # early_stopping monitors the test rou50 metric; if it has not improved
        # for `patience` epochs, training stops (it also checkpoints the current best model)
        logger.info('test_metrics[rou50]: %.5f ' % test_metrics['rou50'])
        early_stopping(test_metrics['rou50'], model)

        if early_stopping.early_stop:
            logger.info('Early stopping')
            break

    with open(best_json_path) as json_file:
        best_metrics = json.load(json_file)
    return best_metrics, test_metrics
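
# A minimal sketch of an EarlyStopping helper compatible with the usage above
# (patience / verbose / delta / folder constructor args, called with a
# lower-is-better metric and the model, exposing an `early_stop` flag).
# This is an assumption for illustration; the project's actual class (and the
# checkpoint file name used here) may differ.
import os
import torch

class EarlyStopping:
    def __init__(self, patience=5, verbose=False, delta=0.0, folder='.'):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.folder = folder
        self.best = None          # best (lowest) metric seen so far
        self.counter = 0          # epochs since the last improvement
        self.early_stop = False

    def __call__(self, metric, model):
        # improvement = metric decreased by more than `delta`
        if self.best is None or metric < self.best - self.delta:
            self.best = metric
            self.counter = 0
            torch.save(model.state_dict(),
                       os.path.join(self.folder, 'early_stop_best.pth'))
            if self.verbose:
                print('Metric improved to %.5f; checkpoint saved.' % metric)
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True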
Example #27
0
def main(args, logger):
    # trn_df = pd.read_csv(f'{MNT_DIR}/inputs/origin/train.csv')
    trn_df = pd.read_pickle(f'{MNT_DIR}/inputs/nes_info/trn_df.pkl')
    trn_df['is_original'] = 1

    # clean texts
    # trn_df = clean_data(trn_df, ['question_title', 'question_body', 'answer'])

    gkf = GroupKFold(n_splits=5).split(
        X=trn_df.question_body,
        groups=trn_df.question_body_le,
    )

    histories = {
        'trn_loss': {},
        'val_loss': {},
        'val_metric': {},
        'val_metric_raws': {},
    }
    loaded_fold = -1
    loaded_epoch = -1
    if args.checkpoint:
        histories, loaded_fold, loaded_epoch = load_checkpoint(args.checkpoint)

    fold_best_metrics = []
    fold_best_metrics_raws = []
    for fold, (trn_idx, val_idx) in enumerate(gkf):
        if fold < loaded_fold:
            fold_best_metrics.append(np.max(histories["val_metric"][fold]))
            fold_best_metrics_raws.append(
                histories["val_metric_raws"][fold][np.argmax(
                    histories["val_metric"][fold])])
            continue
        sel_log(
            f' --------------------------- start fold {fold} --------------------------- ',
            logger)
        fold_trn_df = trn_df.iloc[trn_idx]  # .query('is_original == 1')
        fold_trn_df = fold_trn_df.drop(['is_original', 'question_body_le'],
                                       axis=1)
        # use only original row
        fold_val_df = trn_df.iloc[val_idx].query('is_original == 1')
        fold_val_df = fold_val_df.drop(['is_original', 'question_body_le'],
                                       axis=1)
        if args.debug:
            fold_trn_df = fold_trn_df.sample(100, random_state=71)
            fold_val_df = fold_val_df.sample(100, random_state=71)
        temp = pd.Series(
            list(
                itertools.chain.from_iterable(
                    fold_trn_df.question_title.apply(lambda x: x.split(' ')) +
                    fold_trn_df.question_body.apply(lambda x: x.split(' ')) +
                    fold_trn_df.answer.apply(lambda x: x.split(' '))))
        ).value_counts()
        tokens = temp[temp >= 10].index.tolist()
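        # NOTE: the frequency-based token list above is immediately overridden by
        # the fixed category special tokens below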
        # tokens = []
        tokens = [
            'CAT_TECHNOLOGY'.casefold(),
            'CAT_STACKOVERFLOW'.casefold(),
            'CAT_CULTURE'.casefold(),
            'CAT_SCIENCE'.casefold(),
            'CAT_LIFE_ARTS'.casefold(),
        ]

        trn_dataset = QUESTDataset2(
            df=fold_trn_df,
            mode='train',
            tokens=tokens,
            augment=[],
            tokenizer_type=TOKENIZER_TYPE,
            pretrained_model_name_or_path=TOKENIZER_PRETRAIN,
            do_lower_case=DO_LOWER_CASE,
            Q_LABEL_COL=Q_LABEL_COL,
            A_LABEL_COL=A_LABEL_COL,
            t_max_len=T_MAX_LEN,
            q_max_len=Q_MAX_LEN,
            a_max_len=A_MAX_LEN,
            tqa_mode=TQA_MODE,
            TBSEP='[TBSEP]',
            pos_id_type='arange',
            MAX_SEQUENCE_LENGTH=MAX_SEQ_LEN,
            rm_zero=RM_ZERO,
        )
        # update token
        trn_sampler = RandomSampler(data_source=trn_dataset)
        trn_loader = DataLoader(trn_dataset,
                                batch_size=BATCH_SIZE,
                                sampler=trn_sampler,
                                num_workers=os.cpu_count(),
                                worker_init_fn=lambda x: np.random.seed(),
                                drop_last=True,
                                pin_memory=True)
        val_dataset = QUESTDataset2(
            df=fold_val_df,
            mode='valid',
            tokens=tokens,
            augment=[],
            tokenizer_type=TOKENIZER_TYPE,
            pretrained_model_name_or_path=TOKENIZER_PRETRAIN,
            do_lower_case=DO_LOWER_CASE,
            Q_LABEL_COL=Q_LABEL_COL,
            A_LABEL_COL=A_LABEL_COL,
            t_max_len=T_MAX_LEN,
            q_max_len=Q_MAX_LEN,
            a_max_len=A_MAX_LEN,
            tqa_mode=TQA_MODE,
            TBSEP='[TBSEP]',
            pos_id_type='arange',
            MAX_SEQUENCE_LENGTH=MAX_SEQ_LEN,
            rm_zero=RM_ZERO,
        )
        val_sampler = RandomSampler(data_source=val_dataset)
        val_loader = DataLoader(val_dataset,
                                batch_size=BATCH_SIZE,
                                sampler=val_sampler,
                                num_workers=os.cpu_count(),
                                worker_init_fn=lambda x: np.random.seed(),
                                drop_last=False,
                                pin_memory=True)

        fobj = BCEWithLogitsLoss()
        state_dict = BertModel.from_pretrained(MODEL_PRETRAIN).state_dict()
        model = BertModelForBinaryMultiLabelClassifier2(
            num_labels=len(Q_LABEL_COL) + len(A_LABEL_COL),
            config_path=MODEL_CONFIG_PATH,
            q_state_dict=state_dict,
            a_state_dict=state_dict,
            token_size=len(trn_dataset.tokenizer),
            MAX_SEQUENCE_LENGTH=MAX_SEQ_LEN,
        )
        optimizer = optim.Adam(model.parameters(), lr=3e-5)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                         T_max=MAX_EPOCH,
                                                         eta_min=1e-5)
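        # cosine annealing: the lr decays from 3e-5 down to eta_min=1e-5 over
        # MAX_EPOCH scheduler.step() calls, following a cosine curve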

        # load checkpoint model, optim, scheduler
        if args.checkpoint and fold == loaded_fold:
            load_checkpoint(args.checkpoint, model, optimizer, scheduler)

        for epoch in tqdm(list(range(MAX_EPOCH))):
            if fold <= loaded_fold and epoch <= loaded_epoch:
                continue
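            # warm-up: keep the BERT encoder frozen for the first epoch, then
            # fine-tune it end-to-end from epoch 1 onwards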
            if epoch < 1:
                model.freeze_unfreeze_bert(freeze=True, logger=logger)
            else:
                model.freeze_unfreeze_bert(freeze=False, logger=logger)
            model = DataParallel(model)
            model = model.to(DEVICE)
            trn_loss = train_one_epoch2(model, fobj, optimizer, trn_loader,
                                        DEVICE)
            val_loss, val_metric, val_metric_raws, val_y_preds, val_y_trues, val_qa_ids = test2(
                model, fobj, val_loader, DEVICE, mode='valid')

            scheduler.step()
            # append per-epoch stats, creating the per-fold list on first use
            histories['trn_loss'].setdefault(fold, []).append(trn_loss)
            histories['val_loss'].setdefault(fold, []).append(val_loss)
            histories['val_metric'].setdefault(fold, []).append(val_metric)
            histories['val_metric_raws'].setdefault(fold, []).append(val_metric_raws)

            logging_val_metric_raws = ''
            for val_metric_raw in val_metric_raws:
                logging_val_metric_raws += f'{float(val_metric_raw):.4f}, '

            sel_log(
                f'fold : {fold} -- epoch : {epoch} -- '
                f'trn_loss : {float(trn_loss.detach().to("cpu").numpy()):.4f} -- '
                f'val_loss : {float(val_loss.detach().to("cpu").numpy()):.4f} -- '
                f'val_metric : {float(val_metric):.4f} -- '
                f'val_metric_raws : {logging_val_metric_raws}', logger)
            model = model.to('cpu')
            model = model.module
            save_checkpoint(
                f'{MNT_DIR}/checkpoints/{EXP_ID}/{fold}',
                model,
                optimizer,
                scheduler,
                histories,
                val_y_preds,
                val_y_trues,
                val_qa_ids,
                fold,
                epoch,
                val_loss,
                val_metric,
            )
        fold_best_metrics.append(np.max(histories["val_metric"][fold]))
        fold_best_metrics_raws.append(
            histories["val_metric_raws"][fold][np.argmax(
                histories["val_metric"][fold])])
        save_and_clean_for_prediction(f'{MNT_DIR}/checkpoints/{EXP_ID}/{fold}',
                                      trn_dataset.tokenizer,
                                      clean=False)
        del model

    # calc training stats
    fold_best_metric_mean = np.mean(fold_best_metrics)
    fold_best_metric_std = np.std(fold_best_metrics)
    fold_stats = f'{EXP_ID} : {fold_best_metric_mean:.4f} +- {fold_best_metric_std:.4f}'
    sel_log(fold_stats, logger)
    send_line_notification(fold_stats)

    fold_best_metrics_raws_mean = np.mean(fold_best_metrics_raws, axis=0)
    fold_raw_stats = ''
    for metric_stats_raw in fold_best_metrics_raws_mean:
        fold_raw_stats += f'{float(metric_stats_raw):.4f},'
    sel_log(fold_raw_stats, logger)
    send_line_notification(fold_raw_stats)

    sel_log('now saving best checkpoints...', logger)
Example #28
0
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

utils.mkdir(args.result_dir + 'matfile')
utils.mkdir(args.result_dir + 'png')

test_dataset = get_test_data(args.input_dir)
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=args.bs,
                         shuffle=False,
                         num_workers=8,
                         drop_last=False)

model_restoration = MIRNet()

utils.load_checkpoint(model_restoration, args.weights)
print("===>Testing using weights: ", args.weights)

model_restoration.cuda()

model_restoration = nn.DataParallel(model_restoration)

model_restoration.eval()
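# eval() disables dropout and freezes batch-norm statistics; the torch.no_grad()
# block below additionally turns off gradient tracking for inference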

with torch.no_grad():
    psnr_val_rgb = []
    for ii, data_test in enumerate(tqdm(test_loader), 0):
        rgb_noisy = data_test[0].cuda()
        filenames = data_test[1]
        rgb_restored = model_restoration(rgb_noisy)
        rgb_restored = torch.clamp(rgb_restored, 0, 1)
Example #29
0
ssrn = SSRN().cuda()

optimizer = torch.optim.Adam(ssrn.parameters(), lr=hp.ssrn_lr)

start_timestamp = int(time.time() * 1000)
start_epoch = 0
global_step = 0

logger = Logger(args.dataset, 'ssrn')

# load the last checkpoint if exists
last_checkpoint_file_name = get_last_checkpoint_file_name(logger.logdir)
if last_checkpoint_file_name:
    print("loading the last checkpoint: %s" % last_checkpoint_file_name)
    start_epoch, global_step = load_checkpoint(last_checkpoint_file_name, ssrn,
                                               optimizer)


def get_lr():
    return optimizer.param_groups[0]['lr']


def lr_decay(step, warmup_steps=1000):
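    # Noam-style schedule: linear warmup to hp.ssrn_lr over `warmup_steps` steps,
    # then inverse-square-root decay (lr proportional to 1/sqrt(step)); the peak
    # is reached when step + 1 == warmup_steps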
    new_lr = hp.ssrn_lr * warmup_steps**0.5 * min(
        (step + 1) * warmup_steps**-1.5, (step + 1)**-0.5)
    optimizer.param_groups[0]['lr'] = new_lr


def train(train_epoch, phase='train'):
    global global_step
Example #30
0
        model = MyResNet18(n_in_channels=3, n_fmaps=args.n_fmaps,
                           n_classes=2).to(device)
    elif (args.network_type == "resnet18"):
        model = ResNet18(n_classes=2, pretrained=args.pretrained).to(device)
    else:
        model = ResNet50(n_classes=2,
                         pretrained=args.pretrained,
                         train_only_fc=args.train_only_fc).to(device)

    if (args.debug):
        print("model :\n", model)

    # load the model checkpoint
    if not args.load_checkpoints_path == '' and os.path.exists(
            args.load_checkpoints_path):
        load_checkpoint(model, device, args.load_checkpoints_path)

    #======================================================================
    # optimizer settings
    #======================================================================
    optimizer = optim.Adam(params=model.parameters(),
                           lr=args.lr,
                           betas=(args.beta1, args.beta2))

    #======================================================================
    # loss function settings
    #======================================================================
    loss_fn = nn.CrossEntropyLoss()

    #======================================================================
    # model training
Example #31
0
num_steps = (params.val_size + 1) // params.batch_size
f_score_simple(model, val_data, data_loader, params, num_steps)


# ## Evaluate

# In[14]:


# Define the model
model = net.Net(params).cuda() if params.cuda else net.Net(params)

loss_fn = net.loss_fn
metrics = net.metrics

logging.info("Starting evaluation")

restore_file = 'best'
# Reload weights from the saved file
r = utils.load_checkpoint(os.path.join(model_dir, restore_file + '.pth.tar'), model)


# In[15]:


# Evaluate
num_steps = (params.test_size + 1) // params.batch_size
test_metrics = evaluate(model, loss_fn, test_data, metrics, data_loader, params, num_steps)

Example #32
0
def run_evaluation(args: Namespace, logger: Logger = None):
    """
    Evaluates a saved model
    :param args: Set of args
    :param logger: Logger saved in save_dir
    """

    # Set up logger
    if logger is not None:
        debug, info = logger.debug, logger.info
    else:
        debug = info = print

    debug(pformat(vars(args)))

    # Load metadata
    metadata = json.load(open(args.data_path, 'r'))

    # Train/val/test split
    train_metadata, remaining_metadata = train_test_split(metadata,
                                                          test_size=0.3,
                                                          random_state=0)
    validation_metadata, test_metadata = train_test_split(remaining_metadata,
                                                          test_size=0.5,
                                                          random_state=0)
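    # 70 % train, 15 % validation, 15 % test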

    # Load data
    debug('Loading data')

    transform = Compose([
        Augmentation(args.augmentation_length),
        NNGraph(args.num_neighbors),
        Distance(False)
    ])
    test_data = GlassDataset(test_metadata, transform=transform)
    args.atom_fdim = 3
    args.bond_fdim = args.atom_fdim + 1

    # Dataset lengths
    test_data_length = len(test_data)
    debug('test size = {:,}'.format(test_data_length))

    # Convert to iterators
    test_data = DataLoader(test_data, args.batch_size)

    # Get loss and metric functions
    metric_func = get_metric_func(args.metric)

    # Test ensemble of models
    for model_idx in range(args.ensemble_size):

        # Load/build model
        if args.checkpoint_paths is not None:
            debug('Loading model {} from {}'.format(
                model_idx, args.checkpoint_paths[model_idx]))
            model = load_checkpoint(args.checkpoint_paths[model_idx],
                                    args.save_dir,
                                    cuda=args.cuda,
                                    attention_viz=args.attention_viz)
        else:
            debug('Must specify a model to load')
            exit(1)

        debug(model)
        debug('Number of parameters = {:,}'.format(param_count(model)))

        # Evaluate on test set using model with best validation score
        test_scores = []
        for test_runs in range(args.num_test_runs):

            test_batch_scores = evaluate(model=model,
                                         data=test_data,
                                         metric_func=metric_func,
                                         args=args)

            test_scores.append(np.mean(test_batch_scores))

        # Average test score
        avg_test_score = np.mean(test_scores)
        info('Model {} test {} = {:.3f}'.format(model_idx, args.metric,
                                                avg_test_score))
Example #33
0
        elif params.teacher == "densenet":
            teacher_model = densenet.DenseNet(depth=100, growthRate=12)
            teacher_checkpoint = 'experiments/base_densenet/best.pth.tar'
            teacher_model = nn.DataParallel(teacher_model).cuda()

        elif params.teacher == "resnext29":
            teacher_model = resnext.CifarResNeXt(cardinality=8, depth=29, num_classes=10)
            teacher_checkpoint = 'experiments/base_resnext29/best.pth.tar'
            teacher_model = nn.DataParallel(teacher_model).cuda()

        elif params.teacher == "preresnet110":
            teacher_model = preresnet.PreResNet(depth=110, num_classes=10)
            teacher_checkpoint = 'experiments/base_preresnet110/best.pth.tar'
            teacher_model = nn.DataParallel(teacher_model).cuda()

        utils.load_checkpoint(teacher_checkpoint, teacher_model)

        # Train the model with KD
        logging.info("Experiment - model version: {}".format(params.model_version))
        logging.info("Starting training for {} epoch(s)".format(params.num_epochs))
        logging.info("First, loading the teacher model and computing its outputs...")
        train_and_evaluate_kd(model, teacher_model, train_dl, dev_dl, optimizer, loss_fn_kd,
                              metrics, params, args.model_dir, args.restore_file)
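
        # Sketch of a typical distillation objective, shown only for context; the
        # project's actual loss_fn_kd (and hyperparameters such as params.alpha /
        # params.temperature) may differ:
        #
        #   import torch.nn.functional as F
        #   def loss_fn_kd(student_logits, labels, teacher_logits, params):
        #       T, alpha = params.temperature, params.alpha
        #       soft = F.kl_div(F.log_softmax(student_logits / T, dim=1),
        #                       F.softmax(teacher_logits / T, dim=1),
        #                       reduction='batchmean') * (alpha * T * T)
        #       hard = F.cross_entropy(student_logits, labels) * (1.0 - alpha)
        #       return soft + hard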

    # non-KD mode: regular training of the baseline CNN or ResNet-18
    else:
        if params.model_version == "cnn":
            model = net.Net(params).cuda() if params.cuda else net.Net(params)
            optimizer = optim.Adam(model.parameters(), lr=params.learning_rate)
            # fetch loss function and metrics
            loss_fn = net.loss_fn
Example #34
0
data_loader = torch.utils.data.DataLoader(imagenet_data,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          num_workers=4)
""" model """
D = models_64x64.DiscriminatorWGANGP(3)
G = models_64x64.Generator(z_dim)
utils.cuda([D, G])

d_optimizer = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
g_optimizer = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
""" load checkpoint """
ckpt_dir = './Checkpoints'
utils.mkdir(ckpt_dir)
try:
    ckpt = utils.load_checkpoint(ckpt_dir)
    start_epoch = ckpt['epoch']
    D.load_state_dict(ckpt['D'])
    G.load_state_dict(ckpt['G'])
    d_optimizer.load_state_dict(ckpt['d_optimizer'])
    g_optimizer.load_state_dict(ckpt['g_optimizer'])
except:
    print(' [*] No checkpoint!')
    start_epoch = 0
""" run """
writer = tensorboardX.SummaryWriter('./summaries/celeba_wgan_gp')

z_sample = Variable(torch.randn(100, z_dim))
z_sample = utils.cuda(z_sample)
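
# The WGAN-GP discriminator loss adds a gradient penalty on interpolates between
# real and generated images. A sketch of that term (for context only; the actual
# training step lives in the truncated loop below, and `real` / `fake` are assumed
# batch tensors):
#
#   alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
#   interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
#   d_interp = D(interp)
#   grads = torch.autograd.grad(outputs=d_interp, inputs=interp,
#                               grad_outputs=torch.ones_like(d_interp),
#                               create_graph=True)[0]
#   gp = ((grads.view(grads.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()
#   d_loss = D(fake).mean() - D(real).mean() + 10.0 * gp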
for epoch in range(start_epoch, epochs):
    for i, (imgs, _) in enumerate(data_loader):
Example #35
0
    args = parser.parse_args()
    cnn_dir = 'experiments/cnn_distill'
    json_path = os.path.join(cnn_dir, 'params.json')
    assert os.path.isfile(json_path), "No json configuration file found at {}".format(json_path)
    params = utils.Params(json_path)

    if args.model == "resnet18":
        model = resnet.ResNet18()
        model_checkpoint = 'experiments/base_resnet18/best.pth.tar'

    elif args.model == "wrn":
        model = wrn.wrn(depth=28, num_classes=10, widen_factor=10, dropRate=0.3)
        model_checkpoint = 'experiments/base_wrn/best.pth.tar'

    elif args.model == "distill_resnext":
        model = resnet.ResNet18()
        model_checkpoint = 'experiments/resnet18_distill/resnext_teacher/best.pth.tar'

    elif args.model == "distill_densenet":
        model = resnet.ResNet18()
        model_checkpoint = 'experiments/resnet18_distill/densenet_teacher/best.pth.tar'

    elif args.model == "cnn":
        model = net.Net(params)
        model_checkpoint = 'experiments/cnn_distill/best.pth.tar'

    utils.load_checkpoint(model_checkpoint, model)
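
    # count_parameters is assumed to be the usual trainable-parameter count:
    #   sum(p.numel() for p in model.parameters() if p.requires_grad)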

    model_size = count_parameters(model)
    print("Number of parameters in {} is: {}".format(args.model, model_size))
Example #36
0
b_test_img_paths = glob('./datasets/' + dataset + '/testB/*.jpg')
a_test_pool = data.ImageData(sess, a_test_img_paths, batch_size, load_size=load_size, crop_size=crop_size)
b_test_pool = data.ImageData(sess, b_test_img_paths, batch_size, load_size=load_size, crop_size=crop_size)

a2b_pool = utils.ItemPool()
b2a_pool = utils.ItemPool()
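# ItemPool presumably implements the CycleGAN image buffer: it keeps a history of
# previously generated images and mixes them into the discriminator updates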

'''summary'''
summary_writer = tf.summary.FileWriter('./summaries/' + dataset, sess.graph)

'''saver'''
ckpt_dir = './checkpoints/' + dataset
utils.mkdir(ckpt_dir + '/')

saver = tf.train.Saver(max_to_keep=5)
ckpt_path = utils.load_checkpoint(ckpt_dir, sess, saver)
if ckpt_path is None:
    sess.run(tf.global_variables_initializer())
else:
    print('Copy variables from % s' % ckpt_path)

'''train'''
try:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    batch_epoch = min(len(a_data_pool), len(b_data_pool)) // batch_size
    max_it = epoch * batch_epoch
    for it in range(sess.run(it_cnt), max_it):
        sess.run(update_cnt)