Example no. 1
def test(train_test_unit, out_dir_root):
    output_dir = osp.join(out_dir_root, train_test_unit.metadata['name'])
    mkdir_if_missing(output_dir)
    sys.stdout = Logger(osp.join(output_dir, 'log_test.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    val_path = train_test_unit.test_dir_img
    val_gt_path = train_test_unit.test_dir_den

    if not args.resume:
        pretrained_model = osp.join(output_dir, 'best_model.h5')
    else:
        if args.resume[-3:] == '.h5':
            pretrained_model = args.resume
        else:
            resume_dir = osp.join(args.resume,
                                  train_test_unit.metadata['name'])
            pretrained_model = osp.join(resume_dir, 'best_model.h5')
    print("Using {} for testing.".format(pretrained_model))

    data_loader = ImageDataLoader(val_path,
                                  val_gt_path,
                                  shuffle=False,
                                  batch_size=1)
    mae, mse = evaluate_model(pretrained_model,
                              data_loader,
                              save_test_results=args.save_plots,
                              plot_save_dir=osp.join(output_dir,
                                                     'plot-results-test/'))

    print("MAE: {0:.4f}, MSE: {1:.4f}".format(mae, mse))
Example no. 2
def train(train_test_unit, out_dir_root):
    output_dir = osp.join(out_dir_root, train_test_unit.metadata['name'])
    mkdir_if_missing(output_dir)
    sys.stdout = Logger(osp.join(output_dir, 'log_train.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    dataset_name = train_test_unit.metadata['name']
    train_path = train_test_unit.train_dir_img
    train_gt_path = train_test_unit.train_dir_den
    val_path = train_test_unit.test_dir_img
    val_gt_path = train_test_unit.test_dir_den

    # training configuration
    start_step = args.start_epoch
    end_step = args.max_epoch
    lr = args.lr

    # log frequency
    disp_interval = args.train_batch * 20

    # ------------
    rand_seed = args.seed
    if rand_seed is not None:
        np.random.seed(rand_seed)
        torch.manual_seed(rand_seed)
        torch.cuda.manual_seed(rand_seed)

    # load net
    net = CrowdCounter()
    if not args.resume:
        network.weights_normal_init(net, dev=0.01)
    else:
        # network.weights_normal_init(net, dev=0.01) #init all layers in case of partial net load
        if args.resume[-3:] == '.h5':
            pretrained_model = args.resume
        else:
            resume_dir = osp.join(args.resume,
                                  train_test_unit.metadata['name'])
            pretrained_model = osp.join(resume_dir, 'best_model.h5')
        network.load_net(pretrained_model, net)
        print('Will apply fine-tuning over', pretrained_model)
    net.cuda()
    net.train()

    optimizer_d_large = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                net.d_large.parameters()),
                                         lr=lr,
                                         betas=(args.beta1, args.beta2))
    optimizer_d_small = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                net.d_small.parameters()),
                                         lr=lr,
                                         betas=(args.beta1, args.beta2))
    optimizer_g_large = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                net.g_large.parameters()),
                                         lr=lr,
                                         betas=(args.beta1, args.beta2))
    optimizer_g_small = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                net.g_small.parameters()),
                                         lr=lr,
                                         betas=(args.beta1, args.beta2))

    # training
    train_loss = 0
    step_cnt = 0
    re_cnt = False
    t = Timer()
    t.tic()

    # preprocess flags
    overlap_test = bool(args.overlap_test)

    data_loader = ImageDataLoader(train_path,
                                  train_gt_path,
                                  shuffle=True,
                                  batch_size=args.train_batch,
                                  test_loader=False)
    data_loader_val = ImageDataLoader(val_path,
                                      val_gt_path,
                                      shuffle=False,
                                      batch_size=1,
                                      test_loader=True,
                                      img_width=args.size_x,
                                      img_height=args.size_y,
                                      test_overlap=overlap_test)
    best_mae = sys.maxsize

    for epoch in range(start_step, end_step + 1):
        step = 0
        train_loss_gen_small = 0
        train_loss_gen_large = 0
        train_loss_dis_small = 0
        train_loss_dis_large = 0

        for blob in data_loader:
            step = step + args.train_batch
            im_data = blob['data']
            gt_data = blob['gt_density']
            idx_data = blob['idx']
            im_data_norm = im_data / 127.5 - 1.  # normalize between -1 and 1
            gt_data = gt_data * args.den_factor

            optimizer_d_large.zero_grad()
            optimizer_d_small.zero_grad()
            density_map = net(im_data_norm,
                              gt_data,
                              epoch=epoch,
                              mode="discriminator")
            loss_d_small = net.loss_dis_small
            loss_d_large = net.loss_dis_large
            loss_d_small.backward()
            loss_d_large.backward()
            optimizer_d_small.step()
            optimizer_d_large.step()

            optimizer_g_large.zero_grad()
            optimizer_g_small.zero_grad()
            density_map = net(im_data_norm,
                              gt_data,
                              epoch=epoch,
                              mode="generator")
            loss_g_small = net.loss_gen_small
            loss_g_large = net.loss_gen_large
            loss_g = net.loss_gen
            loss_g.backward()  # loss_g_large + loss_g_small
            optimizer_g_small.step()
            optimizer_g_large.step()

            density_map /= args.den_factor
            gt_data /= args.den_factor

            train_loss_gen_small += loss_g_small.data.item()
            train_loss_gen_large += loss_g_large.data.item()
            train_loss_dis_small += loss_d_small.data.item()
            train_loss_dis_large += loss_d_large.data.item()

            step_cnt += 1
            if step % disp_interval == 0:
                duration = t.toc(average=False)
                fps = step_cnt / duration
                density_map = density_map.data.cpu().numpy()
                train_batch_size = gt_data.shape[0]
                gt_count = np.sum(gt_data.reshape(train_batch_size, -1),
                                  axis=1)
                et_count = np.sum(density_map.reshape(train_batch_size, -1),
                                  axis=1)

                if args.save_plots:
                    plot_save_dir = osp.join(output_dir, 'plot-results-train/')
                    mkdir_if_missing(plot_save_dir)
                    utils.save_results(im_data,
                                       gt_data,
                                       density_map,
                                       idx_data,
                                       plot_save_dir,
                                       loss=args.loss)

                print(
                    "epoch: {0}, step {1}/{5}, Time: {2:.4f}s, gt_cnt: {3:.4f}, et_cnt: {4:.4f}, mean_diff: {6:.4f}"
                    .format(epoch, step, 1. / fps, gt_count[0], et_count[0],
                            data_loader.num_samples,
                            np.mean(np.abs(gt_count - et_count))))
                re_cnt = True

            if re_cnt:
                t.tic()
                re_cnt = False

        save_name = os.path.join(
            output_dir, '{}_{}_{}.h5'.format(train_test_unit.to_string(),
                                             dataset_name, epoch))
        network.save_net(save_name, net)

        # calculate error on the validation dataset
        mae, mse = evaluate_model(save_name,
                                  data_loader_val,
                                  epoch=epoch,
                                  den_factor=args.den_factor)
        if mae < best_mae:
            best_mae = mae
            best_mse = mse
            best_model = '{}_{}_{}.h5'.format(train_test_unit.to_string(),
                                              dataset_name, epoch)
            network.save_net(os.path.join(output_dir, "best_model.h5"), net)

        print(
            "Epoch: {0}, MAE: {1:.4f}, MSE: {2:.4f}, loss gen small: {3:.4f}, loss gen large: {4:.4f}, loss dis small: {5:.4f}, loss dis large: {6:.4f}, loss: {7:.4f}"
            .format(
                epoch, mae, mse, train_loss_gen_small, train_loss_gen_large,
                train_loss_dis_small, train_loss_dis_large,
                train_loss_gen_small + train_loss_gen_large +
                train_loss_dis_small + train_loss_dis_large))
        print("Best MAE: {0:.4f}, Best MSE: {1:.4f}, Best model: {2}".format(
            best_mae, best_mse, best_model))
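
Every variant here redirects `sys.stdout` to a `Logger` that tees console output into `log_train.txt` / `log_test.txt`, and the scripts below also call `sys.stdout.close_open()` once per epoch. A minimal sketch of such a logger, written as an assumption for illustration rather than the project's actual class:

import sys


class TeeLogger(object):
    # Hypothetical stand-in for the Logger used above.

    def __init__(self, fpath, mode='w'):
        self.console = sys.stdout
        self.fpath = fpath
        self.file = open(fpath, mode)

    def write(self, msg):
        self.console.write(msg)
        self.file.write(msg)

    def flush(self):
        self.console.flush()
        self.file.flush()

    def close_open(self):
        # Flush everything to disk between epochs by reopening in append mode.
        self.file.close()
        self.file = open(self.fpath, 'a')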
Example no. 3
def train(train_test_unit, out_dir_root):
    if args.model in ['mcnn4-gan', 'mcnn4-gan-skip', 'mcnn4-gan-u']:
        train_gan(train_test_unit, out_dir_root, args)
        return

    output_dir = osp.join(out_dir_root, train_test_unit.metadata['name'])
    mkdir_if_missing(output_dir)
    output_dir_model = osp.join(output_dir, 'models')
    mkdir_if_missing(output_dir_model)
    if args.resume:
        sys.stdout = Logger(osp.join(output_dir, 'log_train.txt'), mode='a')
    else:
        sys.stdout = Logger(osp.join(output_dir, 'log_train.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    dataset_name = train_test_unit.metadata['name']
    train_path = train_test_unit.train_dir_img
    train_gt_path = train_test_unit.train_dir_den
    val_path = train_test_unit.val_dir_img
    val_gt_path = train_test_unit.val_dir_den

    # training configuration
    start_step = args.start_epoch
    end_step = args.max_epoch
    lr = args.lr

    # log frequency
    disp_interval = args.train_batch * 20

    # ------------
    rand_seed = args.seed
    if rand_seed is not None:
        np.random.seed(rand_seed)
        torch.manual_seed(rand_seed)
        torch.cuda.manual_seed(rand_seed)

    best_mae = sys.maxsize  # best mae
    current_patience = 0

    # load net
    net = CrowdCounter(model=args.model)
    if not args.resume:
        network.weights_normal_init(net, dev=0.01)

    else:
        if args.resume[-3:] == '.h5':  # don't use this option!
            pretrained_model = args.resume
        else:
            resume_dir = osp.join(args.resume,
                                  train_test_unit.metadata['name'])
            if args.last_model:
                pretrained_model = osp.join(resume_dir, 'last_model.h5')
                with open(osp.join(resume_dir, "current_values.bin"),
                          "rb") as f:
                    current_patience = pickle.load(f)
            else:
                pretrained_model = osp.join(resume_dir, 'best_model.h5')
                current_patience = 0
            with open(osp.join(resume_dir, "best_values.bin"), "rb") as f:
                best_mae, best_mse, best_model, _ = pickle.load(f)
            print(
                "Best MAE: {0:.4f}, Best MSE: {1:.4f}, Best model: {2}, Current patience: {3}"
                .format(best_mae, best_mse, best_model, current_patience))

        network.load_net(pretrained_model, net)
        print('Will apply fine-tuning over', pretrained_model)
    net.cuda()
    net.train()

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        net.parameters()),
                                 lr=lr)

    # training
    train_loss = 0
    step_cnt = 0
    re_cnt = False
    t = Timer()
    t.tic()

    data_loader = ImageDataLoader(train_path,
                                  train_gt_path,
                                  shuffle=True,
                                  batch_size=args.train_batch)
    data_loader_val = ImageDataLoader(val_path,
                                      val_gt_path,
                                      shuffle=False,
                                      batch_size=1)

    for epoch in range(start_step, end_step + 1):
        step = 0
        train_loss = 0
        for blob in data_loader:
            optimizer.zero_grad()
            step = step + args.train_batch
            im_data = blob['data']
            gt_data = blob['gt_density']
            im_data_norm = im_data / 127.5 - 1.  # normalize between -1 and 1
            gt_data *= args.den_scale_factor
            density_map = net(im_data_norm, gt_data=gt_data)
            loss = net.loss
            loss.backward()
            optimizer.step()
            train_loss += loss.data.item()
            density_map = density_map.data.cpu().numpy()
            density_map /= args.den_scale_factor
            gt_data /= args.den_scale_factor

            step_cnt += 1
            if step % disp_interval == 0:
                duration = t.toc(average=False)
                fps = step_cnt / duration
                train_batch_size = gt_data.shape[0]
                gt_count = np.sum(gt_data.reshape(train_batch_size, -1),
                                  axis=1)
                et_count = np.sum(density_map.reshape(train_batch_size, -1),
                                  axis=1)

                print(
                    "epoch: {0}, step {1}/{5}, Time: {2:.4f}s, gt_cnt[0]: {3:.4f}, et_cnt[0]: {4:.4f}, mean_diff: {6:.4f}"
                    .format(epoch, step, 1. / fps, gt_count[0], et_count[0],
                            data_loader.num_samples,
                            np.mean(np.abs(gt_count - et_count))))
                re_cnt = True

            if re_cnt:
                t.tic()
                re_cnt = False

        save_name = os.path.join(
            output_dir_model, '{}_{}_{}.h5'.format(train_test_unit.to_string(),
                                                   dataset_name, epoch))
        network.save_net(save_name, net)
        network.save_net(os.path.join(output_dir, "last_model.h5"), net)

        # calculate error on the validation dataset
        mae, mse = evaluate_model(save_name,
                                  data_loader_val,
                                  model=args.model,
                                  save_test_results=args.save_plots,
                                  plot_save_dir=osp.join(
                                      output_dir, 'plot-results-train/'),
                                  den_scale_factor=args.den_scale_factor)
        if mae < best_mae:
            best_mae = mae
            best_mse = mse
            current_patience = 0
            best_model = '{}_{}_{}.h5'.format(train_test_unit.to_string(),
                                              dataset_name, epoch)
            network.save_net(os.path.join(output_dir, "best_model.h5"), net)
            f = open(os.path.join(output_dir, "best_values.bin"), "wb")
            pickle.dump((best_mae, best_mse, best_model, current_patience), f)
            f.close()

        else:
            current_patience += 1

        f = open(os.path.join(output_dir, "current_values.bin"), "wb")
        pickle.dump(current_patience, f)
        f.close()

        print("Epoch: {0}, MAE: {1:.4f}, MSE: {2:.4f}, loss: {3:.4f}".format(
            epoch, mae, mse, train_loss))
        print("Best MAE: {0:.4f}, Best MSE: {1:.4f}, Best model: {2}".format(
            best_mae, best_mse, best_model))
        print("Patience: {0}/{1}".format(current_patience, args.patience))
        sys.stdout.close_open()

        if current_patience > args.patience and args.patience > -1:
            break
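
The `current_patience` bookkeeping above (reset on an improved validation MAE, stop once `current_patience > args.patience` unless patience is -1) is plain early stopping. A self-contained toy illustration with made-up MAE values:

# Toy early-stopping loop; the MAE sequence and patience value are invented.
best_mae = float('inf')
current_patience = 0
patience = 2

for epoch, mae in enumerate([12.0, 10.5, 10.9, 11.2, 11.4, 11.8]):
    if mae < best_mae:
        best_mae = mae
        current_patience = 0  # reset whenever validation improves
    else:
        current_patience += 1  # count non-improving epochs since last best
    print(epoch, mae, best_mae, current_patience)
    if patience > -1 and current_patience > patience:
        break  # patience exhausted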
Example no. 4
def train(train_test_unit, out_dir_root):
    output_dir = osp.join(out_dir_root, train_test_unit.metadata['name'])
    mkdir_if_missing(output_dir)
    sys.stdout = Logger(osp.join(output_dir, 'log_train.txt'))
    print("==========\nArgs:{}\n==========".format(args))

    dataset_name = train_test_unit.metadata['name']
    train_path = train_test_unit.train_dir_img
    train_gt_path = train_test_unit.train_dir_den
    val_path = train_test_unit.test_dir_img
    val_gt_path = train_test_unit.test_dir_den

    # training configuration
    start_step = args.start_epoch
    end_step = args.max_epoch
    lr = args.lr

    # log frequency
    disp_interval = args.train_batch * 20

    # ------------
    rand_seed = args.seed
    if rand_seed is not None:
        np.random.seed(rand_seed)
        torch.manual_seed(rand_seed)
        torch.cuda.manual_seed(rand_seed)

    # load net
    net = CrowdCounter()
    if not args.resume:
        network.weights_normal_init(net, dev=0.01)
    else:
        if args.resume[-3:] == '.h5':
            pretrained_model = args.resume
        else:
            resume_dir = osp.join(args.resume,
                                  train_test_unit.metadata['name'])
            pretrained_model = osp.join(resume_dir, 'best_model.h5')
        network.load_net(pretrained_model, net)
        print('Will apply fine-tuning over', pretrained_model)
    net.cuda()
    net.train()

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        net.parameters()),
                                 lr=lr)

    # training
    train_loss = 0
    step_cnt = 0
    re_cnt = False
    t = Timer()
    t.tic()

    data_loader = ImageDataLoader(train_path,
                                  train_gt_path,
                                  shuffle=True,
                                  batch_size=args.train_batch)
    data_loader_val = ImageDataLoader(val_path,
                                      val_gt_path,
                                      shuffle=False,
                                      batch_size=1)
    best_mae = sys.maxsize

    for epoch in range(start_step, end_step + 1):
        step = 0
        train_loss = 0
        for blob in data_loader:
            optimizer.zero_grad()
            step = step + args.train_batch
            im_data = blob['data']
            gt_data = blob['gt_density']
            im_data_norm = im_data / 255.0
            density_map = net(im_data_norm, gt_data=gt_data)
            loss = net.loss
            loss.backward()
            optimizer.step()
            train_loss += loss.data.item()
            density_map = density_map.data.cpu().numpy()

            step_cnt += 1
            if step % disp_interval == 0:
                duration = t.toc(average=False)
                fps = step_cnt / duration
                train_batch_size = gt_data.shape[0]
                gt_count = np.sum(gt_data.reshape(train_batch_size, -1),
                                  axis=1)
                et_count = np.sum(density_map.reshape(train_batch_size, -1),
                                  axis=1)

                print(
                    "epoch: {0}, step {1}/{5}, Time: {2:.4f}s, gt_cnt[0]: {3:.4f}, et_cnt[0]: {4:.4f}, mean_diff: {6:.4f}"
                    .format(epoch, step, 1. / fps, gt_count[0], et_count[0],
                            data_loader.num_samples,
                            np.mean(np.abs(gt_count - et_count))))
                re_cnt = True

            if re_cnt:
                t.tic()
                re_cnt = False

        save_name = os.path.join(
            output_dir, '{}_{}_{}.h5'.format(train_test_unit.to_string(),
                                             dataset_name, epoch))
        network.save_net(save_name, net)

        # calculate error on the validation dataset
        mae, mse = evaluate_model(save_name,
                                  data_loader_val,
                                  save_test_results=args.save_plots,
                                  plot_save_dir=osp.join(
                                      output_dir, 'plot-results-train/'))
        if mae < best_mae:
            best_mae = mae
            best_mse = mse
            best_model = '{}_{}_{}.h5'.format(train_test_unit.to_string(),
                                              dataset_name, epoch)
            network.save_net(os.path.join(output_dir, "best_model.h5"), net)

        print("Epoch: {0}, MAE: {1:.4f}, MSE: {2:.4f}, loss: {3:.4f}".format(
            epoch, mae, mse, train_loss))
        print("Best MAE: {0:.4f}, Best MSE: {1:.4f}, Best model: {2}".format(
            best_mae, best_mse, best_model))
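
Every checkpoint in these scripts is an `.h5` file handled by `network.save_net` / `network.load_net`. A minimal h5py-based sketch of what such helpers typically look like; the project's actual implementation may differ in details:

import h5py
import numpy as np
import torch


def save_net_sketch(fname, net):
    # Write each entry of the model state dict as an HDF5 dataset.
    with h5py.File(fname, mode='w') as h5f:
        for k, v in net.state_dict().items():
            h5f.create_dataset(k, data=v.cpu().numpy())


def load_net_sketch(fname, net):
    # Copy the stored arrays back into the already-constructed model.
    with h5py.File(fname, mode='r') as h5f, torch.no_grad():
        for k, v in net.state_dict().items():
            v.copy_(torch.from_numpy(np.asarray(h5f[k])))
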
def train_gan(train_test_unit, out_dir_root, args):
    output_dir = osp.join(out_dir_root, train_test_unit.metadata['name'])
    mkdir_if_missing(output_dir)
    output_dir_model = osp.join(output_dir, 'models')
    mkdir_if_missing(output_dir_model)
    if args.resume:
        sys.stdout = Logger(osp.join(output_dir, 'log_train.txt'), mode='a')
        plotter = LossPlotter(output_dir, mode='a')
    else:
        sys.stdout = Logger(osp.join(output_dir, 'log_train.txt'))
        plotter = LossPlotter(output_dir, mode='w')
    print("==========\nArgs:{}\n==========".format(args))

    dataset_name = train_test_unit.metadata['name']
    train_path = train_test_unit.train_dir_img
    train_gt_path = train_test_unit.train_dir_den
    val_path = train_test_unit.val_dir_img
    val_gt_path = train_test_unit.val_dir_den

    # training configuration
    start_step = args.start_epoch
    end_step = args.max_epoch

    # log frequency
    disp_interval = args.train_batch * 20

    # ------------
    rand_seed = args.seed
    if rand_seed is not None:
        np.random.seed(rand_seed)
        torch.manual_seed(rand_seed)
        torch.cuda.manual_seed(rand_seed)

    best_mae = sys.maxsize  # best mae
    current_patience = 0

    mse_criterion = nn.MSELoss()

    # load net and optimizer
    net = CrowdCounter(model=args.model, channel_param=args.channel_param)
    net.cuda()
    net.train()
    #optimizerG = torch.optim.RMSprop(filter(lambda p: p.requires_grad, net.net.parameters()), lr=lr)
    #optimizerD = torch.optim.RMSprop(filter(lambda p: p.requires_grad, net.gan_net.parameters()), lr=lrc)
    optimizerG, optimizerD = get_optimizers(args, net)
    if args.reduce_lr_on_plateau:
        schedulerG = lr_scheduler.ReduceLROnPlateau(
            optimizerG,
            patience=args.scheduler_patience,
            factor=args.scheduler_factor,
            cooldown=args.scheduler_cooldown,
            min_lr=args.min_lr,
            verbose=True)
        schedulerD = lr_scheduler.ReduceLROnPlateau(
            optimizerD,
            patience=args.scheduler_patience,
            factor=args.scheduler_factor,
            cooldown=args.scheduler_cooldown,
            min_lr=args.min_lrc,
            verbose=True)
    elif args.step_lr:
        schedulerG = lr_scheduler.StepLR(optimizerG,
                                         args.scheduler_step_size,
                                         gamma=args.scheduler_gamma,
                                         verbose=True)
        schedulerD = lr_scheduler.StepLR(optimizerD,
                                         args.scheduler_step_size,
                                         gamma=args.scheduler_gamma,
                                         verbose=True)
    if not args.resume:
        network.weights_normal_init(net.net, dev=0.01)

    else:
        if args.resume[-3:] == '.h5':  # don't use this option!
            pretrained_model = args.resume
        else:
            resume_dir = osp.join(args.resume,
                                  train_test_unit.metadata['name'])
            if args.last_model:
                pretrained_model = osp.join(resume_dir, 'last_model.h5')
                f = open(osp.join(resume_dir, "current_values.bin"), "rb")
                current_patience = pickle.load(f)
                f.close()
                f = torch.load(osp.join(resume_dir, 'optimizer.pth'))
                optimizerD.load_state_dict(f['opt_d'])
                optimizerG.load_state_dict(f['opt_g'])
                if args.reduce_lr_on_plateau or args.step_lr:
                    schedulerD.load_state_dict(f['sch_d'])
                    schedulerG.load_state_dict(f['sch_g'])
            else:
                pretrained_model = osp.join(resume_dir, 'best_model.h5')
                current_patience = 0
            f = open(osp.join(resume_dir, "best_values.bin"), "rb")
            best_mae, best_mse, best_model, _ = pickle.load(f)
            f.close()
            print(
                "Best MAE: {0:.4f}, Best MSE: {1:.4f}, Best model: {2}, Current patience: {3}"
                .format(best_mae, best_mse, best_model, current_patience))

        network.load_net(pretrained_model, net)
        print('Will apply fine-tuning over', pretrained_model)

    # training
    train_lossG = 0
    train_lossD = 0
    step_cnt = 0
    re_cnt = False
    t = Timer()
    t.tic()

    # gan labels
    real_label = 1
    fake_label = 0

    netD = net.gan_net
    netG = net.net

    data_loader = ImageDataLoader(train_path,
                                  train_gt_path,
                                  shuffle=True,
                                  batch_size=args.train_batch,
                                  den_scale=1)
    data_loader_val = ImageDataLoader(val_path,
                                      val_gt_path,
                                      shuffle=False,
                                      batch_size=1,
                                      den_scale=1,
                                      testing=True)

    for epoch in range(start_step, end_step + 1):
        step = 0
        train_lossG = 0
        train_lossD = 0
        train_lossG_mse = 0
        train_lossG_gan = 0

        for blob in data_loader:
            optimizerG.zero_grad()
            optimizerD.zero_grad()
            step = step + args.train_batch
            im_data = blob['data']
            gt_data = blob['gt_density']
            im_data_norm_a = im_data / 127.5 - 1.  # normalize between -1 and 1
            gt_data_a = gt_data * args.den_scale_factor

            errD_epoch = 0

            for critic_epoch in range(args.ncritic):
                im_data_norm = network.np_to_variable(im_data_norm_a,
                                                      is_cuda=True,
                                                      is_training=True)
                gt_data = network.np_to_variable(gt_data_a,
                                                 is_cuda=True,
                                                 is_training=True)

                netD.zero_grad()
                netG.zero_grad()

                # real data discriminator
                b_size = gt_data.size(0)
                output_real = netD(gt_data).view(-1)

                # fake data discriminator
                density_map = netG(im_data_norm)
                output_fake = netD(density_map.detach()).view(-1)

                errD = -(torch.mean(output_real) - torch.mean(output_fake))
                errD.backward()
                optimizerD.step()

                for p in netD.parameters():
                    p.data.clamp_(-0.01, 0.01)

                errD_epoch += errD.data.item()

            errD_epoch /= args.ncritic

            # generator update
            netG.zero_grad()
            output_fake = netD(density_map).view(-1)
            errG_gan = -torch.mean(output_fake)
            errG_mse = mse_criterion(density_map, gt_data)
            #errG = (1-args.alpha)*errG_mse + args.alpha*errG_gan
            errG = errG_mse + args.alpha * errG_gan
            errG.backward()
            optimizerG.step()

            train_lossG += errG.data.item()
            train_lossG_mse += errG_mse.data.item()
            train_lossG_gan += errG_gan.data.item()
            train_lossD += errD_epoch
            density_map = density_map.data.cpu().numpy()
            density_map /= args.den_scale_factor
            gt_data = gt_data.data.cpu().numpy()
            gt_data /= args.den_scale_factor

            step_cnt += 1
            if step % disp_interval == 0:
                duration = t.toc(average=False)
                fps = step_cnt / duration
                train_batch_size = gt_data.shape[0]
                gt_count = np.sum(gt_data.reshape(train_batch_size, -1),
                                  axis=1)
                et_count = np.sum(density_map.reshape(train_batch_size, -1),
                                  axis=1)

                print(
                    "epoch: {0}, step {1}/{5}, Time: {2:.4f}s, gt_cnt[0]: {3:.4f}, et_cnt[0]: {4:.4f}, mean_diff: {6:.4f}"
                    .format(epoch, step, 1. / fps, gt_count[0], et_count[0],
                            data_loader.num_samples,
                            np.mean(np.abs(gt_count - et_count))))
                re_cnt = True

            if re_cnt:
                t.tic()
                re_cnt = False

        # save model and optimizer
        save_name = os.path.join(
            output_dir_model, '{}_{}_{}.h5'.format(train_test_unit.to_string(),
                                                   dataset_name, epoch))
        network.save_net(save_name, net)
        network.save_net(os.path.join(output_dir, "last_model.h5"), net)

        # calculate error on the validation dataset
        mae, mse = evaluate_model(save_name,
                                  data_loader_val,
                                  model=args.model,
                                  save_test_results=args.save_plots,
                                  plot_save_dir=osp.join(
                                      output_dir, 'plot-results-train/'),
                                  den_scale_factor=args.den_scale_factor,
                                  channel_param=args.channel_param)
        if mae < best_mae:
            best_mae = mae
            best_mse = mse
            current_patience = 0
            best_model = '{}_{}_{}.h5'.format(train_test_unit.to_string(),
                                              dataset_name, epoch)
            network.save_net(os.path.join(output_dir, "best_model.h5"), net)
            f = open(os.path.join(output_dir, "best_values.bin"), "wb")
            pickle.dump((best_mae, best_mse, best_model, current_patience), f)
            f.close()

        else:
            current_patience += 1

        f = open(os.path.join(output_dir, "current_values.bin"), "wb")
        pickle.dump(current_patience, f)
        f.close()

        # update lr
        if args.reduce_lr_on_plateau:
            schedulerD.step(train_lossG_mse)
            schedulerG.step(train_lossG_mse)
        elif args.step_lr:
            schedulerD.step()
            schedulerG.step()
        optim_dict = {
            "opt_d": optimizerD.state_dict(),
            "opt_g": optimizerG.state_dict()
        }
        if args.reduce_lr_on_plateau or args.step_lr:
            optim_dict['sch_d'] = schedulerD.state_dict()
            optim_dict['sch_g'] = schedulerG.state_dict()
        torch.save(optim_dict, os.path.join(output_dir, "optimizer.pth"))

        plotter.report(train_lossG_mse, train_lossG_gan, train_lossD)
        plotter.save()
        plotter.plot()

        print(
            "Epoch: {0}, MAE: {1:.4f}, MSE: {2:.4f}, lossG: {3:.4f}, lossG_mse: {4:.4f}, lossG_gan: {5:.4f}, lossD: {6:.4f}"
            .format(epoch, mae, mse, train_lossG, train_lossG_mse,
                    train_lossG_gan, train_lossD))
        print("Best MAE: {0:.4f}, Best MSE: {1:.4f}, Best model: {2}".format(
            best_mae, best_mse, best_model))
        print("Patience: {0}/{1}".format(current_patience, args.patience))
        sys.stdout.close_open()

        if current_patience > args.patience and args.patience > -1:
            break
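
The critic/generator updates in `train_gan` follow the standard WGAN recipe: `ncritic` critic steps minimizing -(mean(D(real)) - mean(D(fake))) with weight clipping, then one generator step minimizing an MSE term plus `alpha` times -mean(D(fake)). A self-contained toy sketch of that pattern with stand-in networks and made-up hyper-parameters:

import torch
import torch.nn as nn

# Toy stand-ins for the density-map generator and the critic (not the
# project's CrowdCounter / gan_net architectures).
netG = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),
                     nn.Conv2d(8, 1, 3, padding=1))
netD = nn.Sequential(nn.Conv2d(1, 8, 3, stride=2, padding=1), nn.ReLU(),
                     nn.Conv2d(8, 1, 3, stride=2, padding=1))
optG = torch.optim.RMSprop(netG.parameters(), lr=5e-5)
optD = torch.optim.RMSprop(netD.parameters(), lr=5e-5)
mse_criterion = nn.MSELoss()
alpha = 1e-3  # weight of the adversarial term (invented value)
ncritic = 5

img = torch.rand(4, 1, 32, 32)  # stand-in image batch
gt = torch.rand(4, 1, 32, 32)  # stand-in ground-truth density batch

for _ in range(ncritic):  # critic updates
    netD.zero_grad()
    fake = netG(img)
    errD = -(netD(gt).mean() - netD(fake.detach()).mean())
    errD.backward()
    optD.step()
    for p in netD.parameters():  # WGAN weight clipping
        p.data.clamp_(-0.01, 0.01)

netG.zero_grad()  # generator update
fake = netG(img)
errG = mse_criterion(fake, gt) - alpha * netD(fake).mean()
errG.backward()
optG.step()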