Example 1
def main(args=args):
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    zca_mean = None
    zca_components = None

    # build dataset
    if args.dataset == "Cifar10":
        dataset_base_path = path.join(args.base_path, "dataset", "cifar")
        train_dataset = cifar10_dataset(dataset_base_path)
        test_dataset = cifar10_dataset(dataset_base_path, train_flag=False)
        sampler_valid, sampler_train_l, sampler_train_u = get_ssl_sampler(
            torch.tensor(train_dataset.targets, dtype=torch.int32), 500, 400,
            10)
        test_dloader = DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.workers,
                                  pin_memory=True)
        valid_dloader = DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   sampler=sampler_valid)
        train_dloader_l = DataLoader(train_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_l)
        train_dloader_u = DataLoader(train_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_u)
        num_classes = 10
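        # Optionally load precomputed ZCA whitening statistics (mean vector and
        # whitening matrix) onto the GPU; they are presumably applied to batches
        # inside train/test, since both receive zca_mean / zca_components below.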
        if args.zca:
            zca_mean = np.load(
                os.path.join(dataset_base_path, 'cifar10_zca_mean.npy'))
            zca_components = np.load(
                os.path.join(dataset_base_path, 'cifar10_zca_components.npy'))
            zca_mean = torch.from_numpy(zca_mean).view(1, -1).float().cuda()
            zca_components = torch.from_numpy(zca_components).float().cuda()
    elif args.dataset == "Cifar100":
        dataset_base_path = path.join(args.base_path, "dataset", "cifar")
        train_dataset = cifar100_dataset(dataset_base_path)
        test_dataset = cifar100_dataset(dataset_base_path, train_flag=False)
        sampler_valid, sampler_train_l, sampler_train_u = get_ssl_sampler(
            torch.tensor(train_dataset.targets, dtype=torch.int32), 50, 40,
            100)
        test_dloader = DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.workers,
                                  pin_memory=True)
        valid_dloader = DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   sampler=sampler_valid)
        train_dloader_l = DataLoader(train_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_l)
        train_dloader_u = DataLoader(train_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_u)
        num_classes = 100
    elif args.dataset == "SVHN":
        dataset_base_path = path.join(args.base_path, "dataset", "svhn")
        train_dataset = svhn_dataset(dataset_base_path)
        test_dataset = svhn_dataset(dataset_base_path, train_flag=False)
        sampler_valid, sampler_train_l, sampler_train_u = get_ssl_sampler(
            torch.tensor(train_dataset.labels, dtype=torch.int32), 732, 100,
            10)
        test_dloader = DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.workers,
                                  pin_memory=True)
        valid_dloader = DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   sampler=sampler_valid)
        train_dloader_l = DataLoader(train_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_l)
        train_dloader_u = DataLoader(train_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_u)
        num_classes = 10
    else:
        raise NotImplementedError("Dataset {} Not Implemented".format(
            args.dataset))
    if args.net_name == "wideresnet":
        model = wideresnet.WideResNet(depth=args.depth,
                                      width=args.width,
                                      num_classes=num_classes,
                                      data_parallel=args.dp,
                                      drop_rate=args.dr)
    elif "preact" in args.net_name:
        model = get_preact_resnet(args.net_name,
                                  num_classes=num_classes,
                                  data_parallel=args.dp,
                                  drop_rate=args.dr)
    elif "densenet" in args.net_name:
        model = get_densenet(args.net_name,
                             num_classes=num_classes,
                             data_parallel=args.dp,
                             drop_rate=args.dr)
    else:
        raise NotImplementedError("model {} not implemented".format(
            args.net_name))
    model = model.cuda()

    print(
        "Begin the {} time's semi-supervised training, Dataset: {}, Mixup Method: {}, "
        "Manifold Mixup Method: {}".format(args.train_time, args.dataset,
                                           args.mixup, args.manifold_mixup))
    criterion_l = nn.CrossEntropyLoss()
    criterion_u = nn.MSELoss()
    if args.optimizer == "SGD":
        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.wd,
                                    nesterov=args.nesterov)
    else:
        raise NotImplementedError("{} not find".format(args.optimizer))
    scheduler = MultiStepLR(optimizer,
                            milestones=args.adjust_lr,
                            gamma=args.lr_decay_ratio)
    writer_log_dir = "{}/{}-SSL/runs/train_time:{}".format(
        args.base_path, args.dataset, args.train_time)
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if args.resume_arg:
                args = checkpoint['args']
                args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            raise FileNotFoundError(
                "Checkpoint Resume File {} Not Found".format(args.resume))
    else:
        if os.path.exists(writer_log_dir):
            flag = input(
                "{}-SSL train_time:{} will be removed, input yes to continue:".
                format(args.dataset, args.train_time))
            if flag == "yes":
                shutil.rmtree(writer_log_dir, ignore_errors=True)
    writer = SummaryWriter(log_dir=writer_log_dir)
    for epoch in range(args.start_epoch, args.epochs):
        scheduler.step(epoch)
        if epoch == 0:
            # do warm up
            modify_lr_rate(opt=optimizer, lr=args.wul)
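        # alpha weights the unsupervised consistency term (criterion_u); the
        # alpha_schedule helper is assumed to ramp it up over training.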
        alpha = alpha_schedule(epoch=epoch)
        train(train_dloader_l,
              train_dloader_u,
              model=model,
              criterion_l=criterion_l,
              criterion_u=criterion_u,
              optimizer=optimizer,
              epoch=epoch,
              writer=writer,
              alpha=alpha,
              zca_mean=zca_mean,
              zca_components=zca_components)
        test(valid_dloader,
             test_dloader,
             model=model,
             criterion=criterion_l,
             epoch=epoch,
             writer=writer,
             num_classes=num_classes,
             zca_mean=zca_mean,
             zca_components=zca_components)
        save_checkpoint({
            'epoch': epoch + 1,
            'args': args,
            "state_dict": model.state_dict(),
            'optimizer': optimizer.state_dict(),
        })
        if epoch == 0:
            modify_lr_rate(opt=optimizer, lr=args.lr)
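The sampler helpers used above are not shown. Below is a minimal sketch of what get_ssl_sampler could look like, assuming it draws a fixed number of validation and labeled samples per class and returns three SubsetRandomSampler objects (validation, labeled train, unlabeled train); everything beyond the call signature visible above is an assumption:

import torch
from torch.utils.data.sampler import SubsetRandomSampler


def get_ssl_sampler(labels, n_valid_per_class, n_labeled_per_class, num_classes):
    """Split training indices per class into validation / labeled / unlabeled pools."""
    valid_idx, labeled_idx, unlabeled_idx = [], [], []
    for c in range(num_classes):
        # shuffle the indices of all samples belonging to class c
        idx = torch.nonzero(labels == c, as_tuple=False).flatten()
        idx = idx[torch.randperm(idx.numel())]
        valid_idx.extend(idx[:n_valid_per_class].tolist())
        labeled_idx.extend(
            idx[n_valid_per_class:n_valid_per_class + n_labeled_per_class].tolist())
        unlabeled_idx.extend(idx[n_valid_per_class + n_labeled_per_class:].tolist())
    return (SubsetRandomSampler(valid_idx), SubsetRandomSampler(labeled_idx),
            SubsetRandomSampler(unlabeled_idx))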
Example 2
def main(args=args):
    if args.dataset == "Cifar10":
        dataset_base_path = path.join(args.base_path, "dataset", "cifar")
        train_dataset = cifar10_dataset(dataset_base_path)
        test_dataset = cifar10_dataset(dataset_base_path, train_flag=False)
        sampler_valid, sampler_train_l, sampler_train_u = get_cifar10_ssl_sampler(
            torch.tensor(train_dataset.targets, dtype=torch.int32), 500,
            round(4000 * args.annotated_ratio), 10)
        test_dloader = DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.workers,
                                  pin_memory=True)
        valid_dloader = DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   sampler=sampler_valid)
        train_dloader_l = DataLoader(train_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_l)
        train_dloader_u = DataLoader(train_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_u)
        input_channels = 3
        num_classes = 10
        small_input = True
        criterion = torch.nn.CrossEntropyLoss()
    elif args.dataset == "Cifar100":
        dataset_base_path = path.join(args.base_path, "dataset", "cifar")
        train_dataset = cifar100_dataset(dataset_base_path)
        test_dataset = cifar100_dataset(dataset_base_path, train_flag=False)
        sampler_valid, sampler_train_l, sampler_train_u = get_cifar100_ssl_sampler(
            torch.tensor(train_dataset.targets, dtype=torch.int32), 50,
            round(400 * args.annotated_ratio), 100)
        test_dloader = DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.workers,
                                  pin_memory=True)
        valid_dloader = DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   sampler=sampler_valid)
        train_dloader_l = DataLoader(train_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_l)
        train_dloader_u = DataLoader(train_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_u)
        input_channels = 3
        num_classes = 100
        small_input = True
        criterion = torch.nn.CrossEntropyLoss()
    elif args.dataset == "SVHN":
        dataset_base_path = path.join(args.base_path, "dataset", "svhn")
        train_dataset = svhn_dataset(dataset_base_path)
        test_dataset = svhn_dataset(dataset_base_path, train_flag=False)
        sampler_valid, sampler_train_l, sampler_train_u = get_ssl_sampler(
            torch.tensor(train_dataset.labels, dtype=torch.int32), 100, 100,
            10)
        test_dloader = DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.workers,
                                  pin_memory=True)
        valid_dloader = DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   sampler=sampler_valid)
        train_dloader_l = DataLoader(train_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_l)
        train_dloader_u = DataLoader(train_dataset,
                                     batch_size=args.batch_size,
                                     num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_u)
        input_channels = 3
        num_classes = 10
        small_input = True
        criterion = torch.nn.CrossEntropyLoss()
    else:
        raise NotImplementedError("Dataset {} Not Implemented".format(
            args.dataset))
    model = get_wide_resnet(args.net_name,
                            args.drop_rate,
                            input_channels=input_channels,
                            small_input=small_input,
                            data_parallel=args.dp,
                            num_classes=num_classes)
    model = model.cuda()
    print(
        "Begin the {} Time's Training Semi-Supervised Classifiers, Dataset {}".
        format(args.train_time, args.dataset))
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.beta1,
                                weight_decay=args.wd)
    scheduler = MultiStepLR(optimizer,
                            milestones=args.adjust_lr,
                            gamma=args.lr_decay_ratio)
    writer_log_dir = "{}/{}-SSL-Classifier/runs/train_time:{}".format(
        args.base_path, args.dataset, args.train_time)
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args = checkpoint['args']
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            raise FileNotFoundError(
                "Checkpoint Resume File {} Not Found".format(args.resume))
    else:
        if os.path.exists(writer_log_dir):
            flag = input(
                "{}-SSL-Classifier train_time:{} will be removed, input yes to continue:".
                format(args.dataset, args.train_time))
            if flag == "yes":
                shutil.rmtree(writer_log_dir, ignore_errors=True)
    writer = SummaryWriter(log_dir=writer_log_dir)
    for epoch in range(args.start_epoch, args.epochs):
        if epoch == 0:
            # do warm up
            modify_lr_rate(opt=optimizer, lr=args.lr * 0.2)
        train(train_dloader_l,
              model=model,
              criterion=criterion,
              optimizer=optimizer,
              epoch=epoch,
              writer=writer)
        test(valid_dloader,
             test_dloader,
             model=model,
             criterion=criterion,
             epoch=epoch,
             writer=writer,
             num_classes=num_classes)
        if epoch == 0:
            modify_lr_rate(opt=optimizer, lr=args.lr)
        scheduler.step(epoch)
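The modify_lr_rate helper that implements the one-epoch warm-up above is not defined in these examples. A plausible minimal version (an assumption, not the original helper) simply overwrites the learning rate of every parameter group:

def modify_lr_rate(opt, lr):
    """Set the learning rate of all parameter groups of a torch optimizer."""
    for param_group in opt.param_groups:
        param_group['lr'] = lr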
Example 3
def main(args=args):
    if args.dataset == "Cifar10":
        dataset_base_path = path.join(args.base_path, "dataset", "cifar")
        train_dataset = cifar10_dataset(dataset_base_path)
        test_dataset = cifar10_dataset(dataset_base_path, train_flag=False)
        sampler_valid, sampler_train_l, sampler_train_u = get_cifar10_ssl_sampler(
            torch.tensor(train_dataset.targets, dtype=torch.int32), 500, round(4000 * args.annotated_ratio), 10)
        test_dloader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.workers, pin_memory=True)
        valid_dloader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers, pin_memory=True,
                                   sampler=sampler_valid)
        train_dloader_l = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_l)
        train_dloader_u = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_u)
        input_channels = 3
        small_input = True
        discrete_latent_dim = 10
        args.cmi = 200
        args.dmi = 2.3
        elbo_criterion = VAECriterion(discrete_dim=discrete_latent_dim, x_sigma=args.x_sigma,
                                      bce_reconstruction=args.br).cuda()
        cls_criterion = ClsCriterion()
    elif args.dataset == "Cifar100":
        dataset_base_path = path.join(args.base_path, "dataset", "cifar")
        train_dataset = cifar100_dataset(dataset_base_path)
        test_dataset = cifar100_dataset(dataset_base_path, train_flag=False)
        sampler_valid, sampler_train_l, sampler_train_u = get_cifar100_ssl_sampler(
            torch.tensor(train_dataset.targets, dtype=torch.int32), 50, round(400 * args.annotated_ratio), 100)
        test_dloader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.workers, pin_memory=True)
        valid_dloader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers, pin_memory=True,
                                   sampler=sampler_valid)
        train_dloader_l = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_l)
        train_dloader_u = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_u)
        input_channels = 3
        small_input = True
        discrete_latent_dim = 100
        args.cmi = 1280
        args.dmi = 4.6
        elbo_criterion = VAECriterion(discrete_dim=discrete_latent_dim, x_sigma=args.x_sigma,
                                      bce_reconstruction=args.br).cuda()
        cls_criterion = ClsCriterion()
    elif args.dataset == "SVHN":
        dataset_base_path = path.join(args.base_path, "dataset", "svhn")
        train_dataset = svhn_dataset(dataset_base_path)
        test_dataset = svhn_dataset(dataset_base_path, train_flag=False)
        sampler_valid, sampler_train_l, sampler_train_u = get_ssl_sampler(
            torch.tensor(train_dataset.labels, dtype=torch.int32), 100, 100, 10)
        test_dloader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.workers, pin_memory=True)
        valid_dloader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers, pin_memory=True,
                                   sampler=sampler_valid)
        train_dloader_l = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_l)
        train_dloader_u = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.workers,
                                     pin_memory=True,
                                     sampler=sampler_train_u)
        input_channels = 3
        small_input = True
        discrete_latent_dim = 10
        elbo_criterion = VAECriterion(discrete_dim=discrete_latent_dim, x_sigma=args.x_sigma,
                                      bce_reconstruction=args.br).cuda()
        cls_criterion = ClsCriterion()
    else:
        raise NotImplementedError("Dataset {} not implemented".format(args.dataset))
    model = VariationalAutoEncoder(encoder_name=args.net_name, num_input_channels=input_channels,
                                   drop_rate=args.drop_rate, img_size=tuple(args.image_size), data_parallel=args.dp,
                                   continuous_latent_dim=args.ldc, disc_latent_dim=discrete_latent_dim,
                                   sample_temperature=args.temperature, small_input=small_input)
    model = model.cuda()
    print("Begin the {} Time's Training Semi-Supervised VAE, Dataset {}".format(args.train_time, args.dataset))
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.beta1, weight_decay=args.wd)
    scheduler = MultiStepLR(optimizer, milestones=args.adjust_lr)
    writer_log_dir = "{}/{}-M2-VAE/runs/train_time:{}".format(args.base_path, args.dataset,
                                                                args.train_time)
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args = checkpoint['args']
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            raise FileNotFoundError("Checkpoint Resume File {} Not Found".format(args.resume))
    else:
        if os.path.exists(writer_log_dir):
            flag = input("vae_train_time:{} will be removed, input yes to continue:".format(
                args.train_time))
            if flag == "yes":
                shutil.rmtree(writer_log_dir, ignore_errors=True)
    writer = SummaryWriter(log_dir=writer_log_dir)
    best_valid_loss = 1e10  # tracks the minimum validation ELBO loss (cf. Example 4)
    for epoch in range(args.start_epoch, args.epochs):
        if epoch == 0:
            # do warm up
            modify_lr_rate(opt=optimizer, lr=args.lr * 0.2)
        train(train_dloader_u, train_dloader_l, model=model, elbo_criterion=elbo_criterion, cls_criterion=cls_criterion,
              optimizer=optimizer, epoch=epoch,
              writer=writer, discrete_latent_dim=discrete_latent_dim)
        elbo_valid_loss, *_ = valid(valid_dloader, model=model, elbo_criterion=elbo_criterion, epoch=epoch,
                                    writer=writer, discrete_latent_dim=discrete_latent_dim)
        if test_dloader is not None:
            test(test_dloader, model=model, elbo_criterion=elbo_criterion, epoch=epoch,
                 writer=writer, discrete_latent_dim=discrete_latent_dim)
        """
        Here we define the best point as the minimum average epoch loss
        """
        save_checkpoint({
            'epoch': epoch + 1,
            'args': args,
            "state_dict": model.state_dict(),
            'optimizer': optimizer.state_dict(),
        })
        if elbo_valid_loss < best_valid_loss:
            best_valid_loss = elbo_valid_loss
            if epoch >= args.adjust_lr[-1]:
                save_checkpoint({
                    'epoch': epoch + 1,
                    'args': args,
                    "state_dict": model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, best_predict=True)
        scheduler.step(epoch)
        if epoch == 0:
            modify_lr_rate(opt=optimizer, lr=args.lr)
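save_checkpoint is shared by Examples 1, 3, and 4 but never shown. A hedged sketch, assuming it serializes the passed state into the run directory and keeps a separate copy when best_predict=True (the directory and file names here are illustrative, not the original paths):

import os
import torch


def save_checkpoint(state, best_predict=False, checkpoint_dir="checkpoints"):
    """Persist the latest training state; keep a separate file for the best model."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    filename = "model_best.pth.tar" if best_predict else "checkpoint.pth.tar"
    torch.save(state, os.path.join(checkpoint_dir, filename))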
Example 4
def main(args=args):
    if args.dataset == "Cifar10":
        dataset_base_path = path.join(args.base_path, "dataset", "cifar")
        train_dataset = cifar10_dataset(dataset_base_path)
        test_dataset = cifar10_dataset(dataset_base_path, train_flag=False)
        sampler_valid, sampler_train = get_cifar10_sl_sampler(
            torch.tensor(train_dataset.targets, dtype=torch.int32), 500, 10)
        test_dloader = DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.workers,
                                  pin_memory=True)
        valid_dloader = DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   sampler=sampler_valid)
        train_dloader = DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   sampler=sampler_train)
        input_channels = 3
        small_input = True
    elif args.dataset == "Cifar100":
        dataset_base_path = path.join(args.base_path, "dataset", "cifar")
        train_dataset = cifar100_dataset(dataset_base_path)
        test_dataset = cifar100_dataset(dataset_base_path, train_flag=False)
        sampler_valid, sampler_train = get_cifar100_sl_sampler(
            torch.tensor(train_dataset.targets, dtype=torch.int32), 50, 100)
        test_dloader = DataLoader(test_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=args.workers,
                                  pin_memory=True)
        valid_dloader = DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   sampler=sampler_valid)
        train_dloader = DataLoader(train_dataset,
                                   batch_size=args.batch_size,
                                   num_workers=args.workers,
                                   pin_memory=True,
                                   sampler=sampler_train)
        input_channels = 3
        small_input = True
    else:
        raise NotImplementedError("Dataset {} not implemented".format(
            args.dataset))
    model = VariationalAutoEncoder(num_input_channels=input_channels,
                                   encoder_name=args.net_name,
                                   drop_rate=args.drop_rate,
                                   img_size=tuple(args.image_size),
                                   data_parallel=args.dp,
                                   continuous_latent_dim=args.ldc,
                                   disc_latent_dim=args.ldd,
                                   sample_temperature=args.temperature,
                                   small_input=small_input)
    model = model.cuda()
    input("Begin the {} time's training, Dataset {}".format(
        args.train_time, args.dataset))
    elbo_criterion = VAECriterion(discrete_dim=args.ldd,
                                  x_sigma=args.x_sigma,
                                  bce_reconstruction=args.br).cuda()
    if args.mixup:
        kl_disc_criterion = KLDiscCriterion().cuda()
        kl_norm_criterion = KLNormCriterion().cuda()
    else:
        kl_disc_criterion = None
        kl_norm_criterion = None
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 betas=(args.beta1, 0.999))
    scheduler = MultiStepLR(optimizer, milestones=args.adjust_lr)
    writer_log_dir = "{}/{}-VAE/runs/train_time:{}".format(
        args.base_path, args.dataset, args.train_time)
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if args.resume_arg:
                args = checkpoint['args']
                args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            raise FileNotFoundError(
                "Checkpoint Resume File {} Not Found".format(args.resume))
    else:
        if os.path.exists(writer_log_dir):
            flag = input(
                "vae_train_time:{} will be removed, input yes to continue:".
                format(args.train_time))
            if flag == "yes":
                shutil.rmtree(writer_log_dir, ignore_errors=True)
    writer = SummaryWriter(log_dir=writer_log_dir)
    best_valid_loss = 1e10
    for epoch in range(args.start_epoch, args.epochs):

        train(train_dloader,
              model=model,
              elbo_criterion=elbo_criterion,
              optimizer=optimizer,
              epoch=epoch,
              writer=writer,
              kl_norm_criterion=kl_norm_criterion,
              kl_disc_criterion=kl_disc_criterion)
        elbo_valid_loss, *_ = valid(valid_dloader,
                                    model=model,
                                    elbo_criterion=elbo_criterion,
                                    epoch=epoch,
                                    writer=writer)
        if test_dloader is not None:
            test(test_dloader,
                 model=model,
                 elbo_criterion=elbo_criterion,
                 epoch=epoch,
                 writer=writer)

        save_checkpoint({
            'epoch': epoch + 1,
            'args': args,
            "state_dict": model.state_dict(),
            'optimizer': optimizer.state_dict(),
        })
        if elbo_valid_loss < best_valid_loss:
            best_valid_loss = elbo_valid_loss
            if epoch >= args.adjust_lr[-1]:
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'args': args,
                        "state_dict": model.state_dict(),
                        'optimizer': optimizer.state_dict()
                    },
                    best_predict=True)
        scheduler.step(epoch)
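KLNormCriterion and KLDiscCriterion are only instantiated when args.mixup is set and are not defined here. As an illustration of the Gaussian half (an assumption about the criterion's contract, not the original code), the closed-form KL divergence between a diagonal Gaussian posterior N(mu, sigma^2) and the standard normal prior can be written as:

import torch
import torch.nn as nn


class KLNormCriterion(nn.Module):
    """KL(N(mu, diag(sigma^2)) || N(0, I)): summed over latent dims, averaged over the batch."""

    def forward(self, mu, log_var):
        kl = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)
        return kl.mean()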