Exemplo n.º 1
0
def main():
    """Train the source/target model pair, testing and checkpointing each epoch.

    Uses the module-level ``args`` (``dir``, ``sd``, ``td``, ``lr``, ``b1``,
    ``b2``, ``weight_decay``, ``epoch``) and updates the module-level
    ``best_prec_result`` whenever ``test`` returns a higher value.

    If ``utils.load_checkpoint`` finds a checkpoint in ``args.dir``, the saved
    state and best precision are restored and training continues from the
    epoch after the saved one. Every epoch writes ``latest.pth.tar``.
    ``checkpoint_best.pth.tar`` is written whenever the test result improves.
    """
    global args, best_prec_result

    utils.default_model_dir = args.dir
    start_time = time.time()

    Source_train_loader, Source_test_loader = dataset_selector(args.sd)
    Target_train_loader, Target_test_loader = dataset_selector(args.td)
    # A second, independent train loader over the target dataset; ``train``
    # receives it alongside ``Target_train_loader``.
    Target_shuffle_loader, _ = dataset_selector(args.td)

    state_info = utils.model_optim_state_info()
    state_info.model_init()
    state_info.model_cuda_init()

    if cuda:
        print("USE", torch.cuda.device_count(), "GPUs!")
        state_info.weight_cuda_init()
        cudnn.benchmark = True
    else:
        print("NO GPU")

    state_info.optimizer_init(lr=args.lr, b1=args.b1, b2=args.b2,
                              weight_decay=args.weight_decay)

    start_epoch = 0
    checkpoint = utils.load_checkpoint(utils.default_model_dir)
    if not checkpoint:
        state_info.learning_scheduler_init(args)
    else:
        start_epoch = checkpoint['epoch'] + 1
        best_prec_result = checkpoint['Best_Prec']
        state_info.load_state_dict(checkpoint)
        state_info.learning_scheduler_init(args, load_epoch=start_epoch)

    # Element 0 of the first batch from a fresh iterator over each train
    # loader, passed to ``test`` every epoch. ``next(it)`` is used because
    # Python 3 iterators have no ``.next()`` method.
    realS_sample = to_var(next(iter(Source_train_loader))[0], FloatTensor)
    realT_sample = to_var(next(iter(Target_train_loader))[0], FloatTensor)

    # Start at ``start_epoch`` so the epoch counter agrees with the
    # ``load_epoch`` handed to the scheduler on resume.
    for epoch in range(start_epoch, args.epoch):
        train(state_info, Source_train_loader, Target_train_loader,
              Target_shuffle_loader, epoch)
        prec_result = test(state_info, Source_test_loader, Target_test_loader,
                           realS_sample, realT_sample, epoch)

        if prec_result > best_prec_result:
            best_prec_result = prec_result
            utils.save_state_checkpoint(state_info, best_prec_result,
                                        'checkpoint_best.pth.tar',
                                        utils.default_model_dir, epoch)

        utils.save_state_checkpoint(state_info, best_prec_result,
                                    'latest.pth.tar',
                                    utils.default_model_dir, epoch)
        state_info.learning_step()

    # divmod instead of time.gmtime(): gmtime's tm_hour wraps after 24 h.
    elapsed = int(time.time() - start_time)
    hours, rem = divmod(elapsed, 3600)
    mins, secs = divmod(rem, 60)
    utils.print_log('{} hours {} mins {} secs for training'.format(
        hours, mins, secs))
Exemplo n.º 2
0
def main():
    """Train a 10-class model on ``args.sd`` with a step learning-rate schedule.

    Uses the module-level ``args`` (``dir``, ``sd``, ``lr``, ``epoch`` and
    whatever ``model_init``/``optimizer_init`` read) and updates the
    module-level ``best_prec_result``.

    The learning rate is ``args.lr`` for epochs < 80, ``0.1 * args.lr`` for
    epochs < 120, and ``0.01 * args.lr`` afterwards. If a checkpoint is found
    (``is_last=True``), state and best precision are restored and training
    continues from the epoch after the saved one.
    """
    global args, best_prec_result

    start_epoch = 0
    utils.default_model_dir = args.dir
    start_time = time.time()

    train_loader, test_loader, _, _ = dataset_selector(args.sd)

    state_info = utils.model_optim_state_info()
    state_info.model_init(args=args, num_class=10)
    state_info.model_cuda_init()
    state_info.optimizer_init(args)

    if cuda:
        print("USE", torch.cuda.device_count(), "GPUs!")
        cudnn.benchmark = True

    checkpoint = utils.load_checkpoint(utils.default_model_dir, is_last=True)
    if checkpoint:
        start_epoch = checkpoint['epoch'] + 1
        best_prec_result = checkpoint['Best_Prec']
        state_info.load_state_dict(checkpoint)

    # Resume at ``start_epoch``; the LR is derived from ``epoch`` below, so
    # the schedule stays correct after a restart.
    for epoch in range(start_epoch, args.epoch):
        if epoch < 80:
            lr = args.lr
        elif epoch < 120:
            lr = args.lr * 0.1
        else:
            lr = args.lr * 0.01
        for param_group in state_info.optimizer.param_groups:
            param_group['lr'] = lr

        train(state_info, train_loader, epoch)
        prec_result = test(state_info, test_loader, epoch)

        if prec_result > best_prec_result:
            best_prec_result = prec_result
            utils.save_state_checkpoint(state_info, best_prec_result,
                                        'checkpoint_best.pth.tar',
                                        utils.default_model_dir, epoch)
            # float() accepts both 0-d tensors and plain numbers, unlike .item().
            utils.print_log('Best Prec : {:.4f}'.format(
                float(best_prec_result)))

        utils.save_state_checkpoint(state_info, best_prec_result,
                                    'latest.pth.tar',
                                    utils.default_model_dir, epoch)

    # divmod instead of time.gmtime(): gmtime's tm_hour wraps after 24 h.
    elapsed = int(time.time() - start_time)
    hours, rem = divmod(elapsed, 3600)
    mins, secs = divmod(rem, 60)
    utils.print_log('Best Prec : {:.4f}'.format(float(best_prec_result)))
    utils.print_log('{} hours {} mins {} secs for training'.format(
        hours, mins, secs))

    print('done')
Exemplo n.º 3
0
def train_Epoch(args, state_info, Train_loader, Test_loader):  # all
    """Run ``args.epoch`` epochs of ``train``, saving the best result seen.

    Args:
        args: Run configuration. Reads ``model``, ``dir``, ``epoch``,
            ``use_switch`` and ``iter``, and sets ``last_epoch`` before the
            scheduler is initialised.
        state_info: Container exposing ``model``, ``lr_model``,
            ``load_state_dict`` and ``learning_scheduler_init``.
        Train_loader, Test_loader: Passed straight through to ``train``.

    If a checkpoint exists in ``<args.dir>/<args.model>``, its weights are
    loaded and subsequent checkpoints go to the ``cls`` subdirectory. The
    epoch counter always starts at 0. When ``args.use_switch`` is set, every
    ``args.iter``-th epoch also runs one extra ``train`` pass with
    ``utils.switching_learning`` toggled before and after it.
    """
    start_time = time.time()
    best_prec_result = torch.tensor(0, dtype=torch.float32)
    mode = args.model
    utils.default_model_dir = os.path.join(args.dir, mode)

    # Epochs restart from 0 in both branches, so the scheduler must start
    # fresh in both (previously the checkpoint branch left ``last_epoch``
    # at whatever value the caller had set, if any).
    args.last_epoch = -1
    checkpoint = utils.load_checkpoint(utils.default_model_dir)
    if not checkpoint:
        state_info.learning_scheduler_init(args, mode)
    else:
        print("loading {}/{}".format(utils.default_model_dir,
                                     "checkpoint_best.pth.tar"))
        state_info.load_state_dict(checkpoint, mode)
        state_info.learning_scheduler_init(args, mode)
        utils.default_model_dir = os.path.join(utils.default_model_dir, "cls")

    def _train_and_keep_best(epoch, best):
        """Run one ``train`` pass; save a best checkpoint if it beats ``best``."""
        result = train(args, state_info, Train_loader, Test_loader, epoch)
        if result > best:
            best = result
            utils.save_state_checkpoint(state_info, best, epoch,
                                        'checkpoint_best.pth.tar',
                                        utils.default_model_dir)
            print('save..')
        return best

    for epoch in range(0, args.epoch):
        best_prec_result = _train_and_keep_best(epoch, best_prec_result)

        if args.use_switch and epoch % args.iter == args.iter - 1:
            utils.switching_learning(state_info.model.module)
            print('learning Gate')
            best_prec_result = _train_and_keep_best(epoch, best_prec_result)
            utils.switching_learning(state_info.model.module)
            print('learning Base')

        state_info.lr_model.step()
        utils.print_log('')

    # divmod instead of time.gmtime(): gmtime's tm_hour wraps after 24 h.
    elapsed = int(time.time() - start_time)
    hours, rem = divmod(elapsed, 3600)
    mins, secs = divmod(rem, 60)
    # float() works whether ``train`` returned a tensor or a plain number.
    utils.print_log('Best Prec : {:.4f}'.format(float(best_prec_result)))
    utils.print_log('{} hours {} mins {} secs for training'.format(
        hours, mins, secs))
Exemplo n.º 4
0
def train_Base(args, state_info, All_loader, Test_loader):  # all
    """Train the ``base`` classifier with cross-entropy and evaluate each epoch.

    Args:
        args: Run configuration. Reads ``dir`` and ``epoch``, and sets
            ``last_epoch`` before initialising the scheduler.
        state_info: Container exposing ``base``, ``forward_Base``,
            ``optim_Base``, ``lr_Base`` and checkpoint/scheduler helpers.
        All_loader: Training loader yielding ``(x, y, label_y)`` batches. The
            loss uses ``y``; accuracy is logged against both ``y`` and
            ``label_y``.
        Test_loader: Evaluation loader with the same batch layout. Accuracy
            is measured against its ``y`` element.

    Checkpoints live in ``<args.dir>/base``. An existing one is resumed from.
    ``checkpoint_best.pth.tar`` is written when test accuracy improves, and
    ``latest.pth.tar`` is written every epoch.
    """
    best_prec_result = torch.tensor(0, dtype=torch.float32)
    start_time = time.time()
    cuda = torch.cuda.is_available()
    FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
    LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
    mode = 'base'
    utils.default_model_dir = os.path.join(args.dir, mode)

    criterion = torch.nn.CrossEntropyLoss()

    start_epoch = 0
    checkpoint = utils.load_checkpoint(utils.default_model_dir)
    if not checkpoint:
        args.last_epoch = -1
        state_info.learning_scheduler_init(args, mode)
    else:
        start_epoch = checkpoint['epoch'] + 1
        best_prec_result = checkpoint['Best_Prec']
        state_info.load_state_dict(checkpoint, mode)
        args.last_epoch = start_epoch
        state_info.learning_scheduler_init(args, mode)

    utils.print_log('Type, Epoch, Batch, loss, BCE, KLD, CE')
    state_info.set_train_mode()

    for epoch in range(start_epoch, args.epoch):

        correct_Noise = torch.tensor(0, dtype=torch.float32)
        correct_Real = torch.tensor(0, dtype=torch.float32)
        correct_Test = torch.tensor(0, dtype=torch.float32)
        total = torch.tensor(0, dtype=torch.float32)

        # train (iterates the ``All_loader`` parameter; the old code named
        # an undefined ``Noise_loader`` here)
        state_info.base.train()
        for it, (All, Ay, label_Ay) in enumerate(All_loader):

            All, Ay, label_Ay = to_var(All, FloatTensor), to_var(
                Ay, LongTensor), to_var(label_Ay, LongTensor)

            cls_out = state_info.forward_Base(All)

            state_info.optim_Base.zero_grad()
            loss = criterion(cls_out, Ay)
            loss.backward()
            state_info.optim_Base.step()

            _, pred = torch.max(cls_out.data, 1)
            correct_Noise += float(pred.eq(Ay.data).cpu().sum())
            correct_Real += float(pred.eq(label_Ay.data).cpu().sum())
            total += float(All.size(0))

            if it % 10 == 0:
                msg = 'main Train, {}, {}, {:.6f}, {:.3f}, {:.3f}'.format(
                    epoch, it, loss.item(), 100. * correct_Noise / total,
                    100. * correct_Real / total)
                utils.print_log(msg)
                print(msg)

        total = torch.tensor(0, dtype=torch.float32)
        # test: no autograd graph is needed for evaluation.
        state_info.base.eval()
        with torch.no_grad():
            for it, (Test, Ty, label_Ty) in enumerate(Test_loader):

                Test, Ty, label_Ty = to_var(Test, FloatTensor), to_var(
                    Ty, LongTensor), to_var(label_Ty, LongTensor)

                cls_out = state_info.forward_Base(Test)

                _, pred = torch.max(cls_out.data, 1)
                correct_Test += float(pred.eq(Ty.data).cpu().sum())
                # Count this test batch (the old code used an undefined
                # ``Noise`` here).
                total += float(Test.size(0))

        test_prec = 100. * correct_Test / total
        msg = 'main Test, {}, {}, {:.3f}'.format(epoch, it, test_prec)
        utils.print_log(msg)
        print(msg)

        if test_prec > best_prec_result:
            best_prec_result = test_prec
            utils.save_state_checkpoint(state_info, best_prec_result, epoch,
                                        mode, 'checkpoint_best.pth.tar',
                                        utils.default_model_dir)

        utils.save_state_checkpoint(state_info, best_prec_result, epoch, mode,
                                    'latest.pth.tar', utils.default_model_dir)
        state_info.lr_Base.step()
        utils.print_log('')

    # divmod instead of time.gmtime(): gmtime's tm_hour wraps after 24 h.
    elapsed = int(time.time() - start_time)
    hours, rem = divmod(elapsed, 3600)
    mins, secs = divmod(rem, 60)
    utils.print_log('Best Prec : {:.4f}'.format(float(best_prec_result)))
    utils.print_log('{} hours {} mins {} secs for training'.format(
        hours, mins, secs))


# adversarial_loss = torch.nn.BCELoss()
# criterion_GAN = torch.nn.MSELoss()
# criterion_cycle = torch.nn.L1Loss()
# criterion_identity = torch.nn.L1Loss()
# criterion = nn.CrossEntropyLoss().cuda()
Exemplo n.º 5
0
def train_Triple(args, state_info, Noise_Triple_loader, Test_loader):  # all
    """Train the ``sample`` model with a three-term loss and evaluate each epoch.

    Per batch, with ``out`` the model logits:
    ``0.8 * CE(out, y) + 0.2 * CE(out, argmax(out))
    + 0.5 * NLL(log(1 - softmax(out)), r)``,
    where ``r`` is drawn uniformly from ``[0, 10)``.

    Args:
        args: Run configuration. Reads ``dir``, ``epoch``, ``grad``, ``low``
            and ``high``, and sets ``last_epoch``.
        state_info: Container exposing ``sample``, ``forward_Triple``,
            ``forward_Sample``, ``optim_Sample`` and ``lr_Sample``.
        Noise_Triple_loader: Training loader yielding ``(x, y, label_y)``.
        Test_loader: Evaluation loader yielding ``(x, y, label_y)``. Accuracy
            is measured against ``label_y``.

    Checkpoints live in ``<args.dir>/triple`` and are resumed when present.
    """
    best_prec_result = torch.tensor(0, dtype=torch.float32)
    start_time = time.time()
    cuda = torch.cuda.is_available()
    FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
    LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
    mode = 'sample'
    utils.default_model_dir = os.path.join(args.dir, "triple")

    criterion = torch.nn.CrossEntropyLoss()
    softmax = torch.nn.Softmax(dim=1)

    # Loss weights: CE on given labels, CE on own predictions, and the
    # log(1 - p) term on random labels.
    alpha, beta, gamma = 0.8, 0.2, 0.5
    # Floor for 1 - softmax(out): once softmax saturates to 1.0 in float32,
    # log(0) = -inf would turn the loss and its gradients into NaN.
    eps = 1e-12

    start_epoch = 0
    checkpoint = utils.load_checkpoint(utils.default_model_dir)
    if not checkpoint:
        args.last_epoch = -1
        state_info.learning_scheduler_init(args, mode)
    else:
        start_epoch = checkpoint['epoch'] + 1
        best_prec_result = checkpoint['Best_Prec']
        state_info.load_state_dict(checkpoint, mode)
        args.last_epoch = start_epoch
        state_info.learning_scheduler_init(args, mode)

    utils.print_log('Type, Epoch, Batch, loss, BCE, KLD, CE')

    for epoch in range(start_epoch, args.epoch):

        correct_Noise = torch.tensor(0, dtype=torch.float32)
        correct_Real = torch.tensor(0, dtype=torch.float32)
        correct_Test = torch.tensor(0, dtype=torch.float32)
        total = torch.tensor(0, dtype=torch.float32)

        # train
        state_info.sample.train()
        for it, (Sample, Sy, label_Sy) in enumerate(Noise_Triple_loader):

            Sample, Sy, label_Sy = to_var(Sample, FloatTensor), to_var(
                Sy, LongTensor), to_var(label_Sy, LongTensor)
            Ry = torch.randint_like(Sy, low=0, high=10)

            # NOTE(review): Gamma is built but never used -- forward_Triple is
            # always called with gamma=1 below. Confirm whether it should be
            # passed in.
            if args.grad == "T":
                weight = label_Sy.eq(Sy).type(FloatTensor).view(-1, 1)
                Gamma = WeightedGradientGamma(weight,
                                              low=args.low,
                                              high=args.high)
            elif args.grad == "F":
                Gamma = 1

            Sout = state_info.forward_Triple(Sample, gamma=1)

            _, pred = torch.max(Sout.data, 1)
            state_info.optim_Sample.zero_grad()
            loss_Noise = alpha * criterion(Sout, Sy)
            loss_Pred = beta * criterion(Sout, pred)
            Reverse_log_P = torch.log((1 - softmax(Sout)).clamp(min=eps))
            loss_Ry = gamma * F.nll_loss(Reverse_log_P, Ry)
            loss = loss_Noise + loss_Pred + loss_Ry
            loss.backward()
            state_info.optim_Sample.step()

            correct_Noise += float(pred.eq(Sy.data).cpu().sum())
            correct_Real += float(pred.eq(label_Sy.data).cpu().sum())
            total += float(Sample.size(0))

            if it % 10 == 0:
                msg = 'main Train, {}, {}, {:.6f}, {:.3f}, {:.3f}'.format(
                    epoch, it, loss.item(), 100. * correct_Noise / total,
                    100. * correct_Real / total)
                utils.print_log(msg)
                print(msg)

        # test: iterate the ``Test_loader`` parameter (the old code referred
        # to an undefined ``Noise_Test_loader``). No autograd graph is needed.
        state_info.sample.eval()
        total = torch.tensor(0, dtype=torch.float32)
        with torch.no_grad():
            for it, (Noise, Ny, label_Ny) in enumerate(Test_loader):

                Noise, Ny, label_Ny = to_var(Noise, FloatTensor), to_var(
                    Ny, LongTensor), to_var(label_Ny, LongTensor)

                Nout = state_info.forward_Sample(Noise, gamma=1)

                _, pred = torch.max(Nout.data, 1)
                correct_Test += float(pred.eq(label_Ny.data).cpu().sum())
                total += float(Noise.size(0))

        test_prec = 100. * correct_Test / total
        msg = 'main Test, {}, {}, {:.3f}'.format(epoch, it, test_prec)
        utils.print_log(msg)
        print(msg)

        if test_prec > best_prec_result:
            best_prec_result = test_prec
            utils.save_state_checkpoint(state_info, best_prec_result, epoch,
                                        mode, 'checkpoint_best.pth.tar',
                                        utils.default_model_dir)

        utils.save_state_checkpoint(state_info, best_prec_result, epoch, mode,
                                    'latest.pth.tar', utils.default_model_dir)
        state_info.lr_Sample.step()
        utils.print_log('')

    # divmod instead of time.gmtime(): gmtime's tm_hour wraps after 24 h.
    elapsed = int(time.time() - start_time)
    hours, rem = divmod(elapsed, 3600)
    mins, secs = divmod(rem, 60)
    utils.print_log('Best Prec : {:.4f}'.format(float(best_prec_result)))
    utils.print_log('{} hours {} mins {} secs for training'.format(
        hours, mins, secs))


# adversarial_loss = torch.nn.BCELoss()
# criterion_GAN = torch.nn.MSELoss()
# criterion_cycle = torch.nn.L1Loss()
# criterion_identity = torch.nn.L1Loss()
# criterion = nn.CrossEntropyLoss().cuda()
Exemplo n.º 6
0
def train_Disc2(args, state_info, True_loader, Fake_loader,
                Noise_Test_loader):  # all
    """Train the ``disc`` classifier on ``True_loader`` and evaluate each epoch.

    Batches from ``True_loader`` and ``Fake_loader`` are consumed in lockstep
    via ``zip`` (so an epoch stops at the shorter loader). Only the
    ``True_loader`` batch contributes to the loss, computed as CE against its
    ``y``. The ``Fake_loader`` batch is forwarded with ``gamma=-1`` and used
    only for the logged accuracy.

    Args:
        args: Run configuration. Reads ``dir`` and ``epoch``, and sets
            ``last_epoch``.
        state_info: Container exposing ``disc``, ``forward_disc2``,
            ``optim_Disc`` and ``lr_Disc``.
        True_loader, Fake_loader, Noise_Test_loader: Loaders yielding
            ``(x, y, label_y)``. Test accuracy is measured against ``label_y``.

    Checkpoints live in ``<args.dir>/disc`` and are resumed when present.
    """
    best_prec_result = torch.tensor(0, dtype=torch.float32)
    start_time = time.time()
    cuda = torch.cuda.is_available()
    FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
    LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
    mode = 'disc'
    utils.default_model_dir = os.path.join(args.dir, mode)

    criterion = torch.nn.CrossEntropyLoss()

    # NOTE(review): result is unused; the call is kept in case the helper
    # has side effects (e.g. logging).
    percentage = get_percentage_Fake(Fake_loader)

    start_epoch = 0
    checkpoint = utils.load_checkpoint(utils.default_model_dir)
    if not checkpoint:
        args.last_epoch = -1
        state_info.learning_scheduler_init(args, mode)
    else:
        start_epoch = checkpoint['epoch'] + 1
        best_prec_result = checkpoint['Best_Prec']
        state_info.load_state_dict(checkpoint, mode)
        args.last_epoch = start_epoch
        state_info.learning_scheduler_init(args, mode)

    utils.print_log('Type, Epoch, Batch, loss, BCE, KLD, CE')

    for epoch in range(start_epoch, args.epoch):

        correct_Noise = torch.tensor(0, dtype=torch.float32)
        correct_Real = torch.tensor(0, dtype=torch.float32)
        correct_Test = torch.tensor(0, dtype=torch.float32)
        total = torch.tensor(0, dtype=torch.float32)

        # train
        state_info.disc.train()
        for it, ((real, Ry, label_Ry),
                 (fake, Fy, label_Fy)) in enumerate(zip(True_loader,
                                                        Fake_loader)):

            real, Ry, label_Ry = to_var(real, FloatTensor), to_var(
                Ry, LongTensor), to_var(label_Ry, LongTensor)
            fake, Fy, label_Fy = to_var(fake, FloatTensor), to_var(
                Fy, LongTensor), to_var(label_Fy, LongTensor)

            Rout = state_info.forward_disc2(real, gamma=1)
            Fout = state_info.forward_disc2(fake, gamma=-1)

            state_info.optim_Disc.zero_grad()
            loss = criterion(Rout, Ry)
            loss.backward()
            state_info.optim_Disc.step()

            _, predR = torch.max(Rout.data, 1)
            _, predF = torch.max(Fout.data, 1)
            correct_Noise += float(predF.eq(label_Fy.data).cpu().sum())
            correct_Real += float(predR.eq(label_Ry.data).cpu().sum())
            total += float(real.size(0))

            if it % 10 == 0:
                # The second loss slot is kept for log-column compatibility.
                msg = ('main Train, {}, {}, {:.6f}, {:.6f}, {:.3f}, {:.3f}'.
                       format(epoch, it, loss.item(), loss.item(),
                              100. * correct_Noise / total,
                              100. * correct_Real / total))
                utils.print_log(msg)
                print(msg)

        # test. The former debug block here (softmax/one-hot prints followed
        # by ``print(error)``) raised NameError on the first batch and
        # hard-coded ``torch.cuda.FloatTensor``; it has been removed.
        state_info.disc.eval()
        total = torch.tensor(0, dtype=torch.float32)
        with torch.no_grad():
            for it, (Noise, Ny, label_Ny) in enumerate(Noise_Test_loader):

                Noise, Ny, label_Ny = to_var(Noise, FloatTensor), to_var(
                    Ny, LongTensor), to_var(label_Ny, LongTensor)

                Nout = state_info.forward_disc2(Noise, gamma=1)

                _, pred = torch.max(Nout.data, 1)
                correct_Test += float(pred.eq(label_Ny.data).cpu().sum())
                total += float(Noise.size(0))

        test_prec = 100. * correct_Test / total
        msg = 'main Test, {}, {}, {:.3f}'.format(epoch, it, test_prec)
        utils.print_log(msg)
        print(msg)

        if test_prec > best_prec_result:
            best_prec_result = test_prec
            utils.save_state_checkpoint(state_info, best_prec_result, epoch,
                                        mode, 'checkpoint_best.pth.tar',
                                        utils.default_model_dir)

        utils.save_state_checkpoint(state_info, best_prec_result, epoch, mode,
                                    'latest.pth.tar', utils.default_model_dir)
        state_info.lr_Disc.step()
        utils.print_log('')

    # divmod instead of time.gmtime(): gmtime's tm_hour wraps after 24 h.
    elapsed = int(time.time() - start_time)
    hours, rem = divmod(elapsed, 3600)
    mins, secs = divmod(rem, 60)
    utils.print_log('Best Prec : {:.4f}'.format(float(best_prec_result)))
    utils.print_log('{} hours {} mins {} secs for training'.format(
        hours, mins, secs))


# adversarial_loss = torch.nn.BCELoss()
# criterion_GAN = torch.nn.MSELoss()
# criterion_cycle = torch.nn.L1Loss()
# criterion_identity = torch.nn.L1Loss()
# criterion = nn.CrossEntropyLoss().cuda()
Exemplo n.º 7
0
def main():
    """Train the source/target model pair for ``args.train_iters`` iterations.

    The iteration budget is converted to whole epochs using the shorter of
    the two train loaders. Uses the module-level ``args`` and ``cuda``, and
    updates the module-level ``best_prec_result``.

    If a checkpoint is found in ``args.dir``, state and best precision are
    restored and training continues from the epoch after the saved one.
    ``checkpoint_best.pth.tar`` is written on improvement. ``latest.pth.tar``
    is written every 5 epochs and after the final epoch.
    """
    global args, best_prec_result

    utils.default_model_dir = args.dir
    start_time = time.time()

    Source_train_loader, Source_test_loader = dataset_selector(args.sd)
    Target_train_loader, Target_test_loader = dataset_selector(args.td)

    state_info = utils.model_optim_state_info()
    state_info.model_init()
    state_info.model_cuda_init()

    if cuda:
        print("USE", torch.cuda.device_count(), "GPUs!")
        state_info.weight_cuda_init()
        cudnn.benchmark = True
    else:
        print("NO GPU")

    state_info.optimizer_init(lr=args.lr,
                              b1=args.b1,
                              b2=args.b2,
                              weight_decay=args.weight_decay)

    adversarial_loss = torch.nn.BCELoss()
    # Only move the loss to the GPU when one exists; an unconditional
    # .cuda() crashed the "NO GPU" path handled above.
    criterion = nn.CrossEntropyLoss()
    if cuda:
        criterion = criterion.cuda()

    start_epoch = 0
    checkpoint = utils.load_checkpoint(utils.default_model_dir)
    if checkpoint:
        start_epoch = checkpoint['epoch'] + 1
        best_prec_result = checkpoint['Best_Prec']
        state_info.load_state_dict(checkpoint)

    numEpochs = int(
        math.ceil(
            float(args.train_iters) /
            float(min(len(Source_train_loader), len(Target_train_loader)))))

    # Resume at ``start_epoch`` instead of re-running finished epochs.
    for epoch in range(start_epoch, numEpochs):
        train(state_info, Source_train_loader, Target_train_loader, criterion,
              adversarial_loss, epoch)
        prec_result = test(state_info, Source_test_loader, Target_test_loader,
                           criterion, epoch)

        if prec_result > best_prec_result:
            best_prec_result = prec_result
            utils.save_state_checkpoint(state_info, best_prec_result,
                                        'checkpoint_best.pth.tar',
                                        utils.default_model_dir, epoch)

        # Also save after the last epoch so its state is never lost.
        if epoch % 5 == 0 or epoch == numEpochs - 1:
            utils.save_state_checkpoint(state_info, best_prec_result,
                                        'latest.pth.tar',
                                        utils.default_model_dir, epoch)

    # divmod instead of time.gmtime(): gmtime's tm_hour wraps after 24 h.
    elapsed = int(time.time() - start_time)
    hours, rem = divmod(elapsed, 3600)
    mins, secs = divmod(rem, 60)
    utils.print_log('{} hours {} mins {} secs for training'.format(
        hours, mins, secs))
Exemplo n.º 8
0
def train_Disc(args, state_info, True_loader, Fake_loader,
               Noise_Test_loader):  # all
    """Train ``disc`` as a binary real/fake discriminator and evaluate each epoch.

    Batches from ``True_loader`` get target 1 and batches from
    ``Fake_loader`` get target 0 under ``BCELoss``. The two loaders are
    consumed in lockstep via ``zip``. Accuracy compares the rounded
    discriminator output with whether each sample's ``y`` equals its
    ``label_y``.

    Args:
        args: Run configuration. Reads ``dir`` and ``d_epoch``, and sets
            ``last_epoch``.
        state_info: Container exposing ``disc``, ``forward_disc``,
            ``optim_Disc`` and ``lr_Disc``.
        True_loader, Fake_loader, Noise_Test_loader: Loaders yielding
            ``(x, y, label_y)``.

    Checkpoints live in ``<args.dir>/disc`` and are resumed when present.
    """
    best_prec_result = torch.tensor(0, dtype=torch.float32)
    start_time = time.time()
    cuda = torch.cuda.is_available()
    FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
    LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor
    mode = 'disc'
    utils.default_model_dir = os.path.join(args.dir, mode)

    criterion_GAN = torch.nn.BCELoss()

    # NOTE(review): result is unused; the call is kept in case the helper
    # has side effects (e.g. logging).
    percentage = get_percentage_Fake(Fake_loader)

    start_epoch = 0
    checkpoint = utils.load_checkpoint(utils.default_model_dir)
    if not checkpoint:
        args.last_epoch = -1
        state_info.learning_scheduler_init(args, mode)
    else:
        start_epoch = checkpoint['epoch'] + 1
        best_prec_result = checkpoint['Best_Prec']
        state_info.load_state_dict(checkpoint, mode)
        args.last_epoch = start_epoch
        state_info.learning_scheduler_init(args, mode)

    utils.print_log('Type, Epoch, Batch, loss, BCE, KLD, CE')

    for epoch in range(start_epoch, args.d_epoch):

        correctR = torch.tensor(0, dtype=torch.float32)
        correctF = torch.tensor(0, dtype=torch.float32)
        correctN = torch.tensor(0, dtype=torch.float32)
        # Separate sample counts: zipped real/fake batches can differ in size
        # (e.g. the last batch), so fake accuracy must use the fake count.
        totalR = torch.tensor(0, dtype=torch.float32)
        totalF = torch.tensor(0, dtype=torch.float32)

        # train
        state_info.disc.train()
        for it, ((real, Ry, label_Ry),
                 (fake, Fy, label_Fy)) in enumerate(zip(True_loader,
                                                        Fake_loader)):

            valid = Variable(FloatTensor(real.size(0), 1).fill_(1.0),
                             requires_grad=False)
            unvalid = Variable(FloatTensor(fake.size(0), 1).fill_(0.0),
                               requires_grad=False)

            real, Ry, label_Ry = to_var(real, FloatTensor), to_var(
                Ry, LongTensor), to_var(label_Ry, LongTensor)
            fake, Fy, label_Fy = to_var(fake, FloatTensor), to_var(
                Fy, LongTensor), to_var(label_Fy, LongTensor)

            Rout = state_info.forward_disc(real, Ry)
            Fout = state_info.forward_disc(fake, Fy)

            state_info.optim_Disc.zero_grad()
            loss_real = criterion_GAN(Rout, valid)
            loss_fake = criterion_GAN(Fout, unvalid)
            loss_Disc = (loss_real + loss_fake) / 2
            loss_Disc.backward()
            state_info.optim_Disc.step()

            # Reference: 1 where y == label_y, else 0, shaped (N, 1) to
            # match the rounded discriminator output.
            resultR = label_Ry.eq(Ry).cpu().type(torch.ByteTensor).view(-1, 1)
            predR = torch.round(Rout).cpu().type(torch.ByteTensor)

            resultF = label_Fy.eq(Fy).cpu().type(torch.ByteTensor).view(-1, 1)
            predF = torch.round(Fout).cpu().type(torch.ByteTensor)

            correctR += float(predR.eq(resultR).sum())
            correctF += float(predF.eq(resultF).sum())

            totalR += float(real.size(0))
            totalF += float(fake.size(0))

            if it % 10 == 0:
                msg = ('Disc Train, {}, {}, {:.6f}, {:.6f}, {:.3f}, {:.3f}'.
                       format(epoch, it, loss_real.item(), loss_fake.item(),
                              100. * correctR / totalR,
                              100. * correctF / totalF))
                utils.print_log(msg)
                print(msg)

        # test: no autograd graph is needed for evaluation.
        state_info.disc.eval()
        total = torch.tensor(0, dtype=torch.float32)
        with torch.no_grad():
            for it, (Noise, Ny, label_Ny) in enumerate(Noise_Test_loader):

                Noise, Ny, label_Ny = to_var(Noise, FloatTensor), to_var(
                    Ny, LongTensor), to_var(label_Ny, LongTensor)

                Nout = state_info.forward_disc(Noise, Ny)

                resultN = label_Ny.eq(Ny).cpu().type(
                    torch.ByteTensor).view(-1, 1)
                predN = torch.round(Nout).cpu().type(torch.ByteTensor)

                correctN += float(predN.view(-1, 1).eq(resultN).sum())
                total += float(Noise.size(0))

        test_prec = 100. * correctN / total
        msg = 'Disc Test, {}, {}, {:.3f}'.format(epoch, it, test_prec)
        utils.print_log(msg)
        print(msg)

        if test_prec > best_prec_result:
            best_prec_result = test_prec
            utils.save_state_checkpoint(state_info, best_prec_result, epoch,
                                        mode, 'checkpoint_best.pth.tar',
                                        utils.default_model_dir)

        utils.save_state_checkpoint(state_info, best_prec_result, epoch, mode,
                                    'latest.pth.tar', utils.default_model_dir)
        state_info.lr_Disc.step()
        utils.print_log('')

    # divmod instead of time.gmtime(): gmtime's tm_hour wraps after 24 h.
    elapsed = int(time.time() - start_time)
    hours, rem = divmod(elapsed, 3600)
    mins, secs = divmod(rem, 60)
    utils.print_log('Best Prec : {:.4f}'.format(float(best_prec_result)))
    utils.print_log('{} hours {} mins {} secs for training'.format(
        hours, mins, secs))


# adversarial_loss = torch.nn.BCELoss()
# criterion_GAN = torch.nn.MSELoss()
# criterion_cycle = torch.nn.L1Loss()
# criterion_identity = torch.nn.L1Loss()
# criterion = nn.CrossEntropyLoss().cuda()