Example #1
    def train(self):
        tt = time.time()
        best_loss = 1
        self.model.train()
        self.model.to(self.device)

        self.writer = SummaryWriter(comment='MNIST')
        print("*******Start training*******")
        for epoch in range(self.epochs):
            print(f"---Epoch:{epoch+1}/{self.epochs}---")
            train_metrics = self._train(epoch)
            valid_metrics = self._valid(epoch)
            self.model.train()

            # save model, updating the best validation loss seen so far
            is_best = valid_metrics['valid_loss'] < best_loss
            if is_best:
                best_loss = valid_metrics['valid_loss']
            state = {'epoch': epoch,
                     'model': self.model.state_dict(),
                     'optimizer': self.optimizer.state_dict(),
                     'best_loss': best_loss}
            save_checkpoints(state, is_best, save_dir='MNIST')
        time_elapsed = time.time() - tt
        self.writer.close()
        print(
            f'Training complete in {time_elapsed//60:.0f}m {time_elapsed % 60:.0f}s')
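Note: save_checkpoints is a project-specific helper, and the examples on this page call it with several different signatures. For the (state, is_best, save_dir=...) form used in Example #1, a minimal sketch might look like this; the 'last.pth'/'best.pth' file names are assumptions, not taken from the original project:

import os
import shutil
import torch

def save_checkpoints(state, is_best, save_dir='checkpoints'):
    """Hypothetical helper matching Example #1's call signature.

    Saves `state` as the latest checkpoint and, when `is_best` is True,
    copies it to a separate 'best' file. The file names are assumptions.
    """
    os.makedirs(save_dir, exist_ok=True)
    last_path = os.path.join(save_dir, 'last.pth')
    torch.save(state, last_path)
    if is_best:
        shutil.copyfile(last_path, os.path.join(save_dir, 'best.pth'))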
Example #2
    def save_model(self, name, epoch):
        save_checkpoints(
            self.generator,
            name + 'G',
            epoch,
            optimizer=self.optimizers['G'],
        )
        save_checkpoints(self.discriminator,
                         name + 'D',
                         epoch,
                         optimizer=self.optimizers['D'])
Example #3
def main():
    # define the network
    net = models.LeNetWithAngle(classes_num)
    if use_gpu:
        net = net.cuda()
    # define the optimizer
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=model_lr,
                                weight_decay=1e-5,
                                nesterov=True,
                                momentum=0.9)
    print("net and optimizer loaded successfully")
    # set up the data loaders
    trainloader, testloader = dataloader.get_loader(batch_size=batch_size,
                                                    root_path="./data/MNIST")
    print("data loaded successfully")
    # set up the logger
    logger = utils.Logger(tb_path="./logs/tblog/")
    # set up the learning-rate scheduler
    scheduler = lr_sche.StepLR(optimizer, 30, 0.1)
    # define the loss function
    criterion = a_softmax.AngleSoftmaxLoss(gamma=0)
    best_acc = 0
    # start training
    for i in range(1, epochs + 1):
        scheduler.step(epoch=i)
        net.train()
        train_acc, train_loss, all_feature, all_labels = \
            train(net, optimizer, criterion, trainloader, i)
        utils.plot_features(all_feature, all_labels, classes_num, i,
                            "./logs/images/train/train_{}.png")
        net.eval()
        test_acc, test_loss, all_feature, all_labels = test(
            net, criterion, testloader, i)
        utils.plot_features(all_feature, all_labels, classes_num, i,
                            "./logs/images/test/test_{}.png")
        print("{} epoch end, train acc is {:.4f}, test acc is {:.4f}".format(
            i, train_acc, test_acc))
        content = {
            "Train/acc": train_acc,
            "Test/acc": test_acc,
            "Train/loss": train_loss,
            "Test/loss": test_loss
        }
        logger.log(step=i, content=content)
        if best_acc < test_acc:
            best_acc = test_acc
        utils.save_checkpoints("./logs/weights/net_{}.pth",i,\
            net.state_dict(),(best_acc==test_acc))
    utils.make_gif("./logs/images/train/", "./logs/train.gif")
    utils.make_gif("./logs/images/test/", "./logs/test.gif")
    print("Traing finished...")
Example #4
def run():
    args = argparser()

    path = utils.create_log_dir(sys.argv)
    utils.start(args.http_port)

    env = Env(args)
    agents = [Agent(args) for _ in range(args.n_agent)]
    master = Master(args)

    for agent in agents:
        master.add_agent(agent)
    master.add_env(env)

    success_list = []
    time_list = []

    for idx in range(args.n_episode):
        print('=' * 80)
        print("Episode {}".format(idx + 1))
        # reset the server's stack and timer
        print("Resetting the server...")
        master.reset(path)

        # start the episode
        master.start()
        # train the agents
        master.train()
        print('=' * 80)
        success_list.append(master.infos["is_success"])
        time_list.append(master.infos["end_time"] - master.infos["start_time"])

        if (idx + 1) % args.print_interval == 0:
            print("=" * 80)
            print("EPISODE {}: Avg. Success Rate / Time: {:.2} / {:.2}".format(
                idx + 1, np.mean(success_list), np.mean(time_list)))
            success_list.clear()
            time_list.clear()
            print("=" * 80)

        if (idx + 1) % args.checkpoint_interval == 0:
            utils.save_checkpoints(path, agents, idx + 1)

    if args.visual:
        visualize(path, args)
    print("끝")
    utils.close()
Example #5
def train(model, optimizer, loss_fn, dataloader):
    """ Train the model on `num_steps` batches
    Args:
        model : (torch.nn.Module) model
        optimizer : (torch.optim) optimizer for parameters of model
        loss_fn : (string) a function that takes batch_output and batch_labels and computes the loss for the batch
        dataloader : (DataLoader) a torch.utils.data.DataLoader object that fetches training data
        num_steps : (int) # of batches to train on, each of size args.batch_size
    """

    # set model to training mode
    model.train()

    model_dir = './results/' + model_name
    best_acc = 0.0

    for epoch in range(epochs):
        epoch_loss = 0.0
        epoch_correct = 0.0

        for i, (train_batch, labels_batch) in enumerate(dataloader):
            # move to GPU if available
            if args.gpu:
                train_batch, labels_batch = train_batch.cuda(
                ), labels_batch.cuda()

            # convert to torch Variable
            train_batch, labels_batch = Variable(train_batch), Variable(
                labels_batch)

            # compute model output and loss
            output_batch = model(train_batch)
            loss = loss_fn(output_batch, labels_batch)

            # clear previous gradients, compute gradients of all variables wrt loss
            optimizer.zero_grad()
            loss.backward()

            # performs updates using calculated gradients
            optimizer.step()

            epoch_loss += loss.item()
            acc = utils.accuracy(output_batch.data.cpu().numpy(),
                                 labels_batch.data.cpu().numpy())
            epoch_correct += acc


#             print("Epoch [{}]\t Batch [{}/{}]\t Loss:{:.4f}\t Accuracy:{:.4f}".format(epoch+1, i, len(dataloader), loss.item(), acc))

        print("Epoch [{}/{}]\t Loss:{:.4f}\t Accuracy:{:.4f}%".format(
            epoch + 1, epochs, epoch_loss / len(dataloader),
            100 * epoch_correct / len(dataloader)))

        is_best = acc >= best_acc
        if is_best:
            logging.info("- Found new best accuracy")
            best_acc = acc

        utils.save_checkpoints(
            {
                'epoch': i + 1,
                'state_dict': model.state_dict(),
                'optim_dict': optimizer.state_dict()
            },
            is_best=is_best,
            checkpoint=model_dir)
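Example #5 also relies on a utils.accuracy helper that is not shown. A plausible stand-in, assuming the model outputs per-class scores and the labels are integer class ids, could be:

import numpy as np

def accuracy(outputs, labels):
    """Plausible stand-in for `utils.accuracy` as called in Example #5.

    Assumes `outputs` has shape (batch, num_classes) and `labels` holds
    integer class ids; returns the fraction of correct predictions.
    """
    predictions = np.argmax(outputs, axis=1)
    return float(np.mean(predictions == labels))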
Example #6
def main():
    logger = logging.getLogger('tl')
    args = get_args()
    # CUDA setting
    if not torch.cuda.is_available():
        raise ValueError("Should buy GPU!")
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    device = torch.device('cuda')
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.backends.cudnn.benchmark = True

    def _rescale(img):
        return img * 2.0 - 1.0

    def _noise_adder(img):
        return torch.empty_like(img, dtype=img.dtype).uniform_(0.0,
                                                               1 / 128.0) + img

    # dataset
    dataset_cfg = global_cfg.get('dataset_cfg', {})
    dataset_module = dataset_cfg.get('dataset_module', 'datasets.ImageFolder')

    train_dataset = eval(dataset_module)(os.path.join(args.data_root),
                                         transform=transforms.Compose([
                                             transforms.ToTensor(),
                                             _rescale,
                                             _noise_adder,
                                         ]),
                                         **dataset_cfg.get(
                                             'dataset_kwargs', {}))
    train_loader = iter(
        data.DataLoader(train_dataset,
                        args.batch_size,
                        sampler=InfiniteSamplerWrapper(train_dataset),
                        num_workers=args.num_workers,
                        pin_memory=False))
    if args.calc_FID:
        eval_dataset = datasets.ImageFolder(
            os.path.join(args.data_root, 'val'),
            transforms.Compose([
                transforms.ToTensor(),
                _rescale,
            ]))
        eval_loader = iter(
            data.DataLoader(eval_dataset,
                            args.batch_size,
                            sampler=InfiniteSamplerWrapper(eval_dataset),
                            num_workers=args.num_workers,
                            pin_memory=True))
    else:
        eval_loader = None
    num_classes = len(train_dataset.classes)
    print(' prepared datasets...')
    print(' Number of training images: {}'.format(len(train_dataset)))
    # Prepare directories.
    args.num_classes = num_classes
    args, writer = prepare_results_dir(args)
    # initialize models.
    _n_cls = num_classes if args.cGAN else 0

    gen_module = getattr(
        global_cfg.generator, 'module',
        'pytorch_sngan_projection_lib.models.generators.resnet64')
    model_module = importlib.import_module(gen_module)

    gen = model_module.ResNetGenerator(
        args.gen_num_features,
        args.gen_dim_z,
        args.gen_bottom_width,
        activation=F.relu,
        num_classes=_n_cls,
        distribution=args.gen_distribution).to(device)

    if args.dis_arch_concat:
        dis = SNResNetConcatDiscriminator(args.dis_num_features, _n_cls,
                                          F.relu, args.dis_emb).to(device)
    else:
        dis = SNResNetProjectionDiscriminator(args.dis_num_features, _n_cls,
                                              F.relu).to(device)
    inception_model = inception.InceptionV3().to(
        device) if args.calc_FID else None

    opt_gen = optim.Adam(gen.parameters(), args.lr, (args.beta1, args.beta2))
    opt_dis = optim.Adam(dis.parameters(), args.lr, (args.beta1, args.beta2))

    # gen_criterion = getattr(L, 'gen_{}'.format(args.loss_type))
    # dis_criterion = getattr(L, 'dis_{}'.format(args.loss_type))
    gen_criterion = L.GenLoss(args.loss_type, args.relativistic_loss)
    dis_criterion = L.DisLoss(args.loss_type, args.relativistic_loss)
    print(' Initialized models...\n')

    if args.args_path is not None:
        print(' Load weights...\n')
        prev_args, gen, opt_gen, dis, opt_dis = utils.resume_from_args(
            args.args_path, args.gen_ckpt_path, args.dis_ckpt_path)

    # tf FID
    tf_FID = build_GAN_metric(cfg=global_cfg.GAN_metric)

    class SampleFunc(object):
        def __init__(self, generator, batch, latent, gen_distribution, device):
            self.generator = generator
            self.batch = batch
            self.latent = latent
            self.gen_distribution = gen_distribution
            self.device = device
            pass

        def __call__(self, *args, **kwargs):
            with torch.no_grad():
                self.generator.eval()
                z = utils.sample_z(self.batch, self.latent, self.device,
                                   self.gen_distribution)
                pseudo_y = utils.sample_pseudo_labels(num_classes, self.batch,
                                                      self.device)
                fake_img = self.generator(z, pseudo_y)
            return fake_img

    sample_func = SampleFunc(gen,
                             batch=args.batch_size,
                             latent=args.gen_dim_z,
                             gen_distribution=args.gen_distribution,
                             device=device)

    # Training loop
    for n_iter in tqdm.tqdm(range(1, args.max_iteration + 1)):

        if n_iter >= args.lr_decay_start:
            decay_lr(opt_gen, args.max_iteration, args.lr_decay_start, args.lr)
            decay_lr(opt_dis, args.max_iteration, args.lr_decay_start, args.lr)

        # ==================== Beginning of 1 iteration. ====================
        _l_g = .0
        cumulative_loss_dis = .0
        for i in range(args.n_dis):
            if i == 0:
                fake, pseudo_y, _ = sample_from_gen(args, device, num_classes,
                                                    gen)
                dis_fake = dis(fake, pseudo_y)
                if args.relativistic_loss:
                    real, y = sample_from_data(args, device, train_loader)
                    dis_real = dis(real, y)
                else:
                    dis_real = None

                loss_gen = gen_criterion(dis_fake, dis_real)
                gen.zero_grad()
                loss_gen.backward()
                opt_gen.step()
                _l_g += loss_gen.item()
                if n_iter % 10 == 0 and writer is not None:
                    writer.add_scalar('gen', _l_g, n_iter)

            fake, pseudo_y, _ = sample_from_gen(args, device, num_classes, gen)
            real, y = sample_from_data(args, device, train_loader)

            dis_fake, dis_real = dis(fake, pseudo_y), dis(real, y)
            loss_dis = dis_criterion(dis_fake, dis_real)

            dis.zero_grad()
            loss_dis.backward()
            opt_dis.step()

            cumulative_loss_dis += loss_dis.item()
            if n_iter % 10 == 0 and i == args.n_dis - 1 and writer is not None:
                cumulative_loss_dis /= args.n_dis
                writer.add_scalar('dis', cumulative_loss_dis, n_iter)
        # ==================== End of 1 iteration. ====================

        if n_iter % args.log_interval == 0 or n_iter == 1:
            tqdm.tqdm.write(
                'iteration: {:07d}/{:07d}, loss gen: {:05f}, loss dis {:05f}'.
                format(n_iter, args.max_iteration, _l_g, cumulative_loss_dis))
            if not args.no_image:
                writer.add_image(
                    'fake',
                    torchvision.utils.make_grid(fake,
                                                nrow=4,
                                                normalize=True,
                                                scale_each=True))
                writer.add_image(
                    'real',
                    torchvision.utils.make_grid(real,
                                                nrow=4,
                                                normalize=True,
                                                scale_each=True))
            # Save previews
            utils.save_images(n_iter, n_iter // args.checkpoint_interval,
                              args.results_root, args.train_image_root, fake,
                              real)
        if n_iter % args.checkpoint_interval == 0:
            # Save checkpoints!
            utils.save_checkpoints(args, n_iter,
                                   n_iter // args.checkpoint_interval, gen,
                                   opt_gen, dis, opt_dis)
        if (n_iter % args.eval_interval == 0
                or n_iter == 1) and eval_loader is not None:
            # TODO (crcrpar): implement Inception score, FID, and Geometry score
            # Once these criteria are prepared, val_loader will be used.
            fid_score = evaluation.evaluate(args, n_iter, gen, device,
                                            inception_model, eval_loader)
            tqdm.tqdm.write(
                '[Eval] iteration: {:07d}/{:07d}, FID: {:07f}'.format(
                    n_iter, args.max_iteration, fid_score))
            if writer is not None:
                writer.add_scalar("FID", fid_score, n_iter)
                # Project embedding weights if exists.
                embedding_layer = getattr(dis, 'l_y', None)
                if embedding_layer is not None:
                    writer.add_embedding(embedding_layer.weight.data,
                                         list(range(args.num_classes)),
                                         global_step=n_iter)
        if n_iter % global_cfg.eval_FID_every == 0 or n_iter == 1:
            FID_tf, IS_mean_tf, IS_std_tf = tf_FID(sample_func=sample_func)
            logger.info(
                f'IS_mean_tf:{IS_mean_tf:.3f} +- {IS_std_tf:.3f}\n\tFID_tf: {FID_tf:.3f}'
            )
            if not math.isnan(IS_mean_tf):
                summary_d = {}
                summary_d['FID_tf'] = FID_tf
                summary_d['IS_mean_tf'] = IS_mean_tf
                summary_d['IS_std_tf'] = IS_std_tf
                summary_dict2txtfig(summary_d,
                                    prefix='train',
                                    step=n_iter,
                                    textlogger=global_textlogger)
            gen.train()
    if args.test:
        shutil.rmtree(args.results_root)
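Example #6 wraps its DataLoader in an InfiniteSamplerWrapper so that iter(...) can be drawn from indefinitely inside the iteration-based loop. The actual class is not shown; a plausible implementation (an assumption, following the common pattern for this idiom) is:

import numpy as np
from torch.utils.data import Sampler

class InfiniteSamplerWrapper(Sampler):
    """Plausible sketch of the sampler used above: yields shuffled indices
    forever, so the DataLoader iterator never raises StopIteration."""

    def __init__(self, data_source):
        self.n = len(data_source)

    def __iter__(self):
        while True:
            yield from np.random.permutation(self.n).tolist()

    def __len__(self):
        return 2 ** 31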
Example #7
def main(args):
    if use_cuda:
        torch.cuda.set_device(0)
        torch.cuda.manual_seed(conf.seed)
        cudnn.benchmark = True

    print('===> Building model')
    model = build_model(conf.model)
    NetG = model.NetG()
    NetD = model.NetD()
    print('===> Number of NetG params: {}'.format(
        sum([p.data.nelement() for p in NetG.parameters()])))
    print('===> Number of NetD params: {}'.format(
        sum([p.data.nelement() for p in NetD.parameters()])))

    if use_cuda:
        NetG = NetG.cuda()
        NetD = NetD.cuda()
    # setup optimizer
    decay = conf.decay
    optimizerG = optim.Adam(NetG.parameters(),
                            lr=conf.learning_rate,
                            betas=(conf.beta1, 0.999))
    optimizerD = optim.Adam(NetD.parameters(),
                            lr=conf.learning_rate_netd,
                            betas=(conf.beta1, 0.999))

    # load data
    training_set = DataLoader(dataset=loader_npy.loader_npy(
        image_path=args.train_data_path,
        mask_path=args.train_label_path,
        mode='train'),
                              num_workers=conf.threads,
                              batch_size=conf.batch_size,
                              shuffle=True,
                              pin_memory=True,
                              drop_last=True)

    validation_set = DataLoader(dataset=loader_npy.loader_npy(
        image_path=args.val_data_path,
        mask_path=args.val_label_path,
        mode='val'),
                                num_workers=conf.threads,
                                batch_size=conf.batch_size,
                                shuffle=True,
                                pin_memory=True,
                                drop_last=True)

    start_i = 1
    total_i = conf.epochs

    if not conf.from_scratch:
        cp = utils.get_resume_path('s')
        pretrained_dict = torch.load(cp)
        model_dict = NetG.state_dict()
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        NetG.load_state_dict(model_dict)

        cp_name = os.path.basename(cp)

        cp2 = utils.get_resume_path('c')
        print('#####  resume_path(c):', cp2)
        NetD.load_state_dict(torch.load(cp2))
        print('---> Loading checkpoint {}...'.format(cp_name))
        start_i = int(cp_name.split('_')[-1].split('.')[0]) + 1

    max_iou = 0
    print('===> Begin training at epoch {}'.format(start_i))
    for epoch in range(start_i, total_i + 1):
        print("---------------eppch[{}]-------------------".format(epoch))
        train(NetG, NetD, optimizerG, optimizerD, training_set, epoch)

        if epoch % 2 == 0:
            val(NetG, validation_set, epoch, max_iou)

            utils.save_checkpoints(NetG, epoch, 's')
            utils.save_checkpoints(NetD, epoch, 'c')

        if epoch % 20 == 0:
            conf.learning_rate = max(conf.learning_rate * decay, 1e-8)
            conf.learning_rate_netd = max(conf.learning_rate_netd * decay, 1e-8)

            print('Learning Rate: {:.6f}'.format(conf.learning_rate))
            optimizerG = optim.Adam(NetG.parameters(),
                                    lr=conf.learning_rate,
                                    betas=(conf.beta1, 0.999))
            optimizerD = optim.Adam(NetD.parameters(),
                                    lr=conf.learning_rate_netd,
                                    betas=(conf.beta1, 0.999))
Example #8
        disc_loss.backward()
        optimizer.step()

        disc_sum += disc_inv_loss.item()
    pb.update()
    train_summary_writer.add_scalar("disc_loss",
                                    disc_sum / args.samples_per_epoch,
                                    global_step=global_itr)

    if global_itr % args.checkpoint_freq == 0:
        checkpoint = {
            "model": dvml_model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "global_itr": global_itr
        }
        utils.save_checkpoints(checkpoint,
                               os.path.join(args.log_dir,
                                            "model_{}.pth".format(global_itr)),
                               max_chkps=args.max_checkpoints)

checkpoint = {
    "model": dvml_model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "global_itr": args.num_epochs - 1
}
utils.save_checkpoints(
    checkpoint,
    os.path.join(args.log_dir, "model_{}.pth".format(args.num_epochs - 1)),
    max_chkps=args.max_checkpoints)
train_summary_writer.close()
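Example #8 passes max_chkps to its save_checkpoints, suggesting a helper that rotates old checkpoint files. A minimal sketch of such a helper, assuming the 'model_*.pth' naming used above:

import glob
import os
import torch

def save_checkpoints(state, path, max_chkps=5):
    """Hypothetical rotating-checkpoint helper for Example #8's call style.

    Writes `state` to `path`, then removes the oldest 'model_*.pth' files
    in the same directory so at most `max_chkps` checkpoints remain."""
    torch.save(state, path)
    pattern = os.path.join(os.path.dirname(path), 'model_*.pth')
    checkpoints = sorted(glob.glob(pattern), key=os.path.getmtime)
    for old in checkpoints[:-max_chkps]:
        os.remove(old)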
Example #9
def train(dataloader):
    """ Train the model on `num_steps` batches
    Args:
        dataloader : (DataLoader) a torch.utils.data.DataLoader object that fetches training data
        num_steps : (int) # of batches to train on, each of size args.batch_size
    """

    # Define Generator, Discriminator
    G = Generator(out_channel=ch).to(device) # MNIST channel: 1, CIFAR-10 channel: 3
    D = Discriminator(in_channel=ch).to(device)

    # adversarial loss
    loss_fn = nn.BCELoss()

    # Initialize weights
    G.apply(init_weights)
    D.apply(init_weights)

    optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(b1, b2))
    optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(b1, b2))

    # Establish convention for real and fake labels during training
    real_label = 1.
    fake_label = 0.

    # -----Training----- #
    for epoch in range(epochs):
        # For each batch in the dataloader

        for i, data in enumerate(dataloader, 0):
            ############################
            # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
            ###########################
            ## Train with all-real batch
            D.zero_grad()
            # Format batch
            real_cpu = data[0].to(device) # load image batch size
            b_size = real_cpu.size(0) # batch size
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device, requires_grad=False) # real batch

            # Forward pass **real batch** through D
            output = D(real_cpu).view(-1)
            # Calculate loss on all-real batch
            errD_real = loss_fn(output, label)
            # Calculate gradients for D in backward pass
            errD_real.backward()

            ## Train with **all-fake** batch
            # Generate noise batch of latent vectors
            noise = torch.randn(b_size, latent_dim, 1, 1, device=device)
            # Generate fake image batch with G
            fake = G(noise)
            label.fill_(fake_label) # fake batch

            # Classify all fake batch with D
            output = D(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = loss_fn(output, label)
            # Calculate the gradients for this batch
            errD_fake.backward()
            # Add the gradients from the all-real and all-fake batches
            errD = errD_real + errD_fake
            # Update D
            optimizer_D.step()

            ############################
            # (2) Update G network: maximize log(D(G(z)))
            ###########################
            G.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost

            # Since we just updated D, perform another forward pass of all-fake batch through D
            output = D(fake).view(-1)
            # Calculate G's loss based on this output
            errG = loss_fn(output, label)
            # Calculate gradients for G
            errG.backward()
            # Update G
            optimizer_G.step()

            # Save fake images generated by Generator
            batches_done = epoch * len(dataloader) + i
            if batches_done % 400 == 0:
                save_image(fake.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

        print(f"[Epoch {epoch + 1}/{epochs}] [D loss: {errD.item():.4f}] [G loss: {errG.item():.4f}]")

        # Save Generator model's parameters
        save_checkpoints(
            {'epoch': epoch + 1,
             'state_dict': G.state_dict(),
             'optim_dict': optimizer_G.state_dict()},
            checkpoint='./ckpt/',
            is_G=True
        )

        # Save Discriminator model's parameters
        save_checkpoints(
            {'epoch': epoch + 1,
             'state_dict': D.state_dict(),
             'optim_dict': optimizer_D.state_dict()},
            checkpoint='./ckpt/',
            is_G=False
        )
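Example #9 applies an init_weights function that is not shown. Example #15 below defines an equivalent weights_init explicitly; a DCGAN-style sketch along the same lines (an assumption, not the source's actual code) is:

import torch.nn as nn

def init_weights(m):
    """Plausible DCGAN-style initializer for Example #9's G.apply/D.apply,
    similar in spirit to the weights_init defined in Example #15."""
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)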
Example #10
def main():
    args = get_args()
    # CUDA setting
    if not torch.cuda.is_available():
        raise ValueError("Should buy GPU!")
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    device = torch.device('cuda')
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.backends.cudnn.benchmark = True

    def _rescale(img):
        return img * 2.0 - 1.0

    def _noise_adder(img):
        return torch.empty_like(img, dtype=img.dtype).uniform_(0.0,
                                                               1 / 128.0) + img

    # dataset
    train_dataset = datasets.ImageFolder(
        os.path.join(args.data_root, 'train'),
        transforms.Compose([
            transforms.ToTensor(),
            _rescale,
            _noise_adder,
        ]))
    train_loader = iter(
        data.DataLoader(train_dataset,
                        args.batch_size,
                        sampler=InfiniteSamplerWrapper(train_dataset),
                        num_workers=args.num_workers,
                        pin_memory=True))
    if args.calc_FID:
        eval_dataset = datasets.ImageFolder(
            os.path.join(args.data_root, 'val'),
            transforms.Compose([
                transforms.ToTensor(),
                _rescale,
            ]))
        eval_loader = iter(
            data.DataLoader(eval_dataset,
                            args.batch_size,
                            sampler=InfiniteSamplerWrapper(eval_dataset),
                            num_workers=args.num_workers,
                            pin_memory=True))
    else:
        eval_loader = None
    num_classes = len(train_dataset.classes)
    print(' prepared datasets...')
    print(' Number of training images: {}'.format(len(train_dataset)))
    # Prepare directories.
    args.num_classes = num_classes
    args, writer = prepare_results_dir(args)
    # initialize models.
    _n_cls = num_classes if args.cGAN else 0
    gen = ResNetGenerator(args.gen_num_features,
                          args.gen_dim_z,
                          args.gen_bottom_width,
                          activation=F.relu,
                          num_classes=_n_cls,
                          distribution=args.gen_distribution).to(device)
    if args.dis_arch_concat:
        dis = SNResNetConcatDiscriminator(args.dis_num_features, _n_cls,
                                          F.relu, args.dis_emb).to(device)
    else:
        dis = SNResNetProjectionDiscriminator(args.dis_num_features, _n_cls,
                                              F.relu).to(device)
    inception_model = inception.InceptionV3().to(
        device) if args.calc_FID else None

    opt_gen = optim.Adam(gen.parameters(), args.lr, (args.beta1, args.beta2))
    opt_dis = optim.Adam(dis.parameters(), args.lr, (args.beta1, args.beta2))

    # gen_criterion = getattr(L, 'gen_{}'.format(args.loss_type))
    # dis_criterion = getattr(L, 'dis_{}'.format(args.loss_type))
    gen_criterion = L.GenLoss(args.loss_type, args.relativistic_loss)
    dis_criterion = L.DisLoss(args.loss_type, args.relativistic_loss)
    print(' Initialized models...\n')

    if args.args_path is not None:
        print(' Load weights...\n')
        prev_args, gen, opt_gen, dis, opt_dis = utils.resume_from_args(
            args.args_path, args.gen_ckpt_path, args.dis_ckpt_path)

    # Training loop
    for n_iter in tqdm.tqdm(range(1, args.max_iteration + 1)):

        if n_iter >= args.lr_decay_start:
            decay_lr(opt_gen, args.max_iteration, args.lr_decay_start, args.lr)
            decay_lr(opt_dis, args.max_iteration, args.lr_decay_start, args.lr)

        # ==================== Beginning of 1 iteration. ====================
        _l_g = .0
        cumulative_loss_dis = .0
        for i in range(args.n_dis):
            if i == 0:
                fake, pseudo_y, _ = sample_from_gen(args, device, num_classes,
                                                    gen)
                dis_fake = dis(fake, pseudo_y)
                if args.relativistic_loss:
                    real, y = sample_from_data(args, device, train_loader)
                    dis_real = dis(real, y)
                else:
                    dis_real = None

                loss_gen = gen_criterion(dis_fake, dis_real)
                gen.zero_grad()
                loss_gen.backward()
                opt_gen.step()
                _l_g += loss_gen.item()
                if n_iter % 10 == 0 and writer is not None:
                    writer.add_scalar('gen', _l_g, n_iter)

            fake, pseudo_y, _ = sample_from_gen(args, device, num_classes, gen)
            real, y = sample_from_data(args, device, train_loader)

            dis_fake, dis_real = dis(fake, pseudo_y), dis(real, y)
            loss_dis = dis_criterion(dis_fake, dis_real)

            dis.zero_grad()
            loss_dis.backward()
            opt_dis.step()

            cumulative_loss_dis += loss_dis.item()
            if n_iter % 10 == 0 and i == args.n_dis - 1 and writer is not None:
                cumulative_loss_dis /= args.n_dis
                writer.add_scalar('dis', cumulative_loss_dis, n_iter)
        # ==================== End of 1 iteration. ====================

        if n_iter % args.log_interval == 0:
            tqdm.tqdm.write(
                'iteration: {:07d}/{:07d}, loss gen: {:05f}, loss dis {:05f}'.
                format(n_iter, args.max_iteration, _l_g, cumulative_loss_dis))
            if not args.no_image:
                writer.add_image(
                    'fake',
                    torchvision.utils.make_grid(fake,
                                                nrow=4,
                                                normalize=True,
                                                scale_each=True))
                writer.add_image(
                    'real',
                    torchvision.utils.make_grid(real,
                                                nrow=4,
                                                normalize=True,
                                                scale_each=True))
            # Save previews
            utils.save_images(n_iter, n_iter // args.checkpoint_interval,
                              args.results_root, args.train_image_root, fake,
                              real)
        if n_iter % args.checkpoint_interval == 0:
            # Save checkpoints!
            utils.save_checkpoints(args, n_iter,
                                   n_iter // args.checkpoint_interval, gen,
                                   opt_gen, dis, opt_dis)
        if n_iter % args.eval_interval == 0 and eval_loader is not None:
            # TODO (crcrpar): implement Inception score, FID, and Geometry score
            # Once these criteria are prepared, val_loader will be used.
            fid_score = evaluation.evaluate(args, n_iter, gen, device,
                                            inception_model, eval_loader)
            tqdm.tqdm.write(
                '[Eval] iteration: {:07d}/{:07d}, FID: {:07f}'.format(
                    n_iter, args.max_iteration, fid_score))
            if writer is not None:
                writer.add_scalar("FID", fid_score, n_iter)
                # Project embedding weights if exists.
                embedding_layer = getattr(dis, 'l_y', None)
                if embedding_layer is not None:
                    writer.add_embedding(embedding_layer.weight.data,
                                         list(range(args.num_classes)),
                                         global_step=n_iter)
    if args.test:
        shutil.rmtree(args.results_root)
Example #11
def main():
    args = get_args()
    weight_path, img_path = directory_path(args)

    # CUDA setting
    if not torch.cuda.is_available():
        raise ValueError("Should buy GPU!")

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.backends.cudnn.benchmark = True

    # dataloading
    train_loader, s_dlen, _n_cls = data_loader2(args)

    fixed_z = torch.randn(200, 10, 128)
    fixed_img_list, fixed_label_list = pick_fixed_img(args, train_loader, 200)

    # initialize model
    gen, dis = select_model(args, _n_cls)

    opt_gen = optim.Adam(gen.parameters(), args.lr, (args.beta1, args.beta2))
    opt_dis = optim.Adam(dis.parameters(), args.lr, (args.beta1, args.beta2))

    gen_criterion = L.GenLoss(args.loss_type, args.relativistic_loss)
    dis_criterion = L.DisLoss(args.loss_type, args.relativistic_loss)

    criterion = nn.CrossEntropyLoss()

    # Training loop
    for n_iter in tqdm.tqdm(range(0, args.max_iteration)):

        if n_iter >= args.lr_decay_start:
            decay_lr(opt_gen, args.max_iteration, args.lr_decay_start, args.lr)
            decay_lr(opt_dis, args.max_iteration, args.lr_decay_start, args.lr)

        # ==================== Beginning of 1 iteration. ====================
        _l_g = .0
        cumulative_loss_dis = .0
        for i in range(args.n_dis):
            if i == 0:
                fake, pseudo_y, _ = sample_from_gen(args, dev, _n_cls, gen)
                dis_fake, dis_mi, dis_c = dis(fake, pseudo_y)
                dis_real = None

                loss_gen = gen_criterion(dis_fake, dis_real)

                ##################################################
                loss_mi = criterion(dis_mi, pseudo_y)
                loss_c = criterion(dis_c, pseudo_y)

                loss_gen = loss_gen + args.lambda_c * (loss_c - loss_mi)
                ##################################################

                gen.zero_grad()
                loss_gen.backward()
                opt_gen.step()
                _l_g += loss_gen.item()

            fake, pseudo_y, _ = sample_from_gen(args, dev, _n_cls, gen)
            real, y = sample_from_data(args, dev, train_loader)

            dis_fake, dis_fake_mi, dis_fake_c = dis(fake, pseudo_y)
            dis_real, dis_real_mi, dis_real_c = dis(real, y)

            ######################################################
            loss_dis_mi = criterion(dis_fake_mi, pseudo_y)
            loss_dis_c = criterion(dis_real_c, y)
            ######################################################

            loss_dis = dis_criterion(dis_fake, dis_real)
            loss_dis = loss_dis + args.lambda_c * (loss_dis_mi + loss_dis_c)

            dis.zero_grad()
            loss_dis.backward()
            opt_dis.step()

            cumulative_loss_dis += loss_dis.item()
        # ==================== End of 1 iteration. ====================

        if n_iter % args.log_interval == 0:
            tqdm.tqdm.write(
                'iteration: {:07d}/{:07d}, loss gen: {:05f}, loss dis {:05f}'
                ' loss mi {:05f}, loss c {:05f}'.format(
                    n_iter, args.max_iteration, _l_g, cumulative_loss_dis,
                    args.lambda_c * loss_dis_mi, args.lambda_c * loss_dis_c))

        if n_iter % args.checkpoint_interval == 0:
            # Save checkpoints!
            utils.save_checkpoints(args, n_iter, gen, opt_gen, dis, opt_dis,
                                   weight_path)
            if args.dataset == "omniglot":
                utils.save_img(fixed_img_list,
                               fixed_label_list,
                               fixed_z,
                               gen,
                               32,
                               28,
                               img_path,
                               n_iter,
                               device=dev)
            elif args.dataset == "vgg" or args.dataset == "animal":
                utils.save_img(fixed_img_list,
                               fixed_label_list,
                               fixed_z,
                               gen,
                               84,
                               64,
                               img_path,
                               n_iter,
                               device=dev)
            elif args.dataset == "cub":
                utils.save_img(fixed_img_list,
                               fixed_label_list,
                               fixed_z,
                               gen,
                               72,
                               64,
                               img_path,
                               n_iter,
                               device=dev)
            else:
                raise Exception("Enter model omniglot or vgg or animal or cub")

    if args.test:
        shutil.rmtree(args.results_root)
Example #12
    pred = model(inp)

    optimizer.zero_grad()
    loss = criterion(pred, target)
    loss_avg += loss.item()
    loss.backward()

    for p in model.parameters():
        p.grad.data.clamp_(max=args.clip)

    optimizer.step()
    global_step += 1

    if global_step % args.print_freq == 0:
        print("at step {} loss {}".format(global_step,
                                          loss_avg / args.print_freq))
        writer.scalar_summary("train_bce", loss_avg / args.print_freq,
                              global_step)
        loss_avg = 0

    if args.save_freq and global_step % args.save_freq == 0:
        utils.save_checkpoints(
            {
                'global_step': global_step,
                'args': args,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict()
            },
            os.path.join(args.model_dir,
                         args.model_prefix + str(global_step) + '.pt'))
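Example #12 clamps each gradient tensor by hand. The modern idiom is torch.nn.utils.clip_grad_value_, though note it clips to the symmetric range [-clip, clip], whereas the loop above clamps only the upper bound (max=args.clip), so the two are not exactly equivalent; model and args are the names from the example:

import torch

torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=args.clip)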
Example #13
def main():
    args = get_args()
    weight_path, img_path = directory_path(args)

    # CUDA setting
    if not torch.cuda.is_available():
        raise ValueError("Should buy GPU!")

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    torch.backends.cudnn.benchmark = True

    # dataloading
    train_loader, s_dlen, _n_cls = data_loader2(args)

    fixed_z = torch.randn(200, 10, 128)
    fixed_img_list, fixed_label_list = pick_fixed_img(args, train_loader, 200)

    # initialize model
    gen, dis, start = select_model(args, _n_cls)
    
    gen, dis = gen.cuda(), dis.cuda()
    
    opt_gen = optim.Adam(gen.parameters(), args.lr, (args.beta1, args.beta2))
    opt_dis = optim.Adam(dis.parameters(), args.lr, (args.beta1, args.beta2))

    gen_criterion = L.GenLoss(args.loss_type, args.relativistic_loss)
    dis_criterion = L.DisLoss(args.loss_type, args.relativistic_loss)

    # Training loop
    for n_iter in tqdm.tqdm(range(start, args.max_iteration)):

        if n_iter >= args.lr_decay_start:
            decay_lr(opt_gen, args.max_iteration, args.lr_decay_start, args.lr)
            decay_lr(opt_dis, args.max_iteration, args.lr_decay_start, args.lr)

        # ==================== Beginning of 1 iteration. ====================
        _l_g = .0
        cumulative_loss_dis = .0
        for i in range(args.n_dis):
            if i == 0:
                fake, pseudo_y, _ = sample_from_gen(args, dev, _n_cls, gen)
                fake = DiffAugment(fake, policy=policy)
                dis_fake = dis(fake, pseudo_y)
                dis_real = None

                loss_gen = gen_criterion(dis_fake, dis_real)
                gen.zero_grad()
                loss_gen.backward()
                opt_gen.step()
                _l_g += loss_gen.item()

            fake, pseudo_y, _ = sample_from_gen(args, dev, _n_cls, gen)
            real, y = sample_from_data(args, dev, train_loader)

            fake = DiffAugment(fake, policy=policy)
            real = DiffAugment(real, policy=policy)

            dis_fake, dis_real = dis(fake, pseudo_y), dis(real, y)
            loss_dis = dis_criterion(dis_fake, dis_real)

            dis.zero_grad()
            loss_dis.backward()
            opt_dis.step()

            cumulative_loss_dis += loss_dis.item()
        # ==================== End of 1 iteration. ====================

        if n_iter % args.log_interval == 0:
            tqdm.tqdm.write(
                'iteration: {:07d}/{:07d}, loss gen: {:05f}, loss dis {:05f}'.format(
                    n_iter, args.max_iteration, _l_g, cumulative_loss_dis))

        if n_iter % args.checkpoint_interval == 0:
            # Save checkpoints!
            utils.save_checkpoints(args, n_iter, gen, opt_gen, dis, opt_dis, weight_path)
            utils.save_img(fixed_img_list, fixed_label_list, fixed_z, gen,
                           32, 28, img_path, n_iter, device=dev)
    if args.test:
        shutil.rmtree(args.results_root)
Example #14
def train(cfg, logger, vis):
    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        split=cfg["data"]["train_split"],
        patch_size=cfg['data']['patch_size'],
        augmentation=cfg['data']['aug_data']
    )

    v_loader = data_loader(
        data_path,
        split=cfg["data"]["val_split"],
    )

    trainloader = DataLoader(
        t_loader,
        batch_size=cfg["batch_size"],
        num_workers=cfg["n_workers"],
        shuffle=True,
    )

    valloader = DataLoader(
        v_loader, batch_size=cfg["batch_size"], num_workers=cfg["n_workers"]
    )

    # Setup model, optimizer and loss function
    model_cls = get_model(cfg['model'])
    model = model_cls(cfg).to(device)

    optimizer_cls = get_optimizer(cfg)
    optimizer_params = {k: v for k, v in cfg["optimizer"].items() if k != "name"}
    optimizer = optimizer_cls(model.parameters(), **optimizer_params)

    scheduler = MultiStepLR(optimizer, milestones=[15000, 17500], gamma=0.1)

    crit = get_critical(cfg['critical'])().to(device)
    ssim = SSIM().to(device)

    step = 0

    if cfg['resume'] is not None:
        pass

    while step < cfg['max_iters']:
        scheduler.step()
        model.train()

        if cfg['model'] == 'rescan':
            O, B, prediction = inference_rescan(model=model, optimizer=optimizer, dataloader=trainloader,
                                                critical=crit, ssim=ssim,
                                                step=step, vis=vis)
        if cfg['model'] == 'did_mdn':
            O, B, prediction, label = inference_didmdn(model=model, optimizer=optimizer, dataloader=trainloader,
                                                       critical=crit, ssim=ssim,
                                                       step=step, vis=vis)

        if step % 10 == 0:
            model.eval()
            if cfg['model'] == 'rescan':
                O, B, prediction_v = inference_rescan(model=model, optimizer=optimizer, dataloader=valloader,
                                                      critical=crit, ssim=ssim,
                                                      step=step, vis=vis)
            if cfg['model'] == 'did_mdn':
                O, B, prediction, label = inference_didmdn(model=model, optimizer=optimizer,
                                                           dataloader=valloader,
                                                           critical=crit, ssim=ssim,
                                                           step=step, vis=vis)

        if step % int(cfg['save_steps'] / 16) == 0:
            save_checkpoints(model, step, optimizer, cfg['checkpoint_dir'], 'latest')
        if step % int(cfg['save_steps'] / 2) == 0:
            save_image('train', [O.cpu(), prediction.cpu(), B.cpu()], cfg['checkpoint_dir'], step, cfg['batch_size'])
            if step % 10 == 0:
                save_image('val', [O.cpu(), prediction.cpu(), B.cpu()], cfg['checkpoint_dir'], step, cfg['batch_size'])
            logger.info('save image as step_%d' % step)
        if step % cfg['save_steps'] == 0:
            save_checkpoints(model=model,
                             step=step,
                             optim=optimizer,
                             model_dir=cfg['checkpoint_dir'],
                             name='{}_step_{}'.format(cfg['model'] + cfg['data']['dataset'], step))
            logger.info('save model as step_%d' % step)
        step += 1
Example #15
def train_gan(cfg, logger, vis):
    # Setup seeds
    torch.manual_seed(cfg.get("seed", 1337))
    torch.cuda.manual_seed(cfg.get("seed", 1337))
    np.random.seed(cfg.get("seed", 1337))
    random.seed(cfg.get("seed", 1337))

    # Setup Dataloader
    data_loader = get_loader(cfg["data"]["dataset"])
    data_path = cfg["data"]["path"]

    t_loader = data_loader(
        data_path,
        split=cfg["data"]["train_split"],
        patch_size=cfg['data']['patch_size'],
        augmentation=cfg['data']['aug_data']
    )

    train_loader = DataLoader(
        t_loader,
        batch_size=cfg["batch_size"],
        num_workers=cfg["n_workers"],
        shuffle=True,
    )

    # custom weights initialization called on netG and netD
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
            m.bias.data.fill_(0)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)

    ndf = cfg['ndf']
    ngf = cfg['ngf']
    nc = 3

    netD_cls = get_model(cfg['netd'])
    netG_cls = get_model(cfg['netg'])

    netD = netD_cls(nc, cfg['output_nc'], ndf).to(device)
    netG = netG_cls(cfg['input_nc'], cfg['output_nc'], ngf).to(device)

    netG.apply(weights_init)
    netD.apply(weights_init)
    logger.info(netD)
    logger.info(netG)

    ###########   LOSS & OPTIMIZER   ##########
    criterion = torch.nn.BCELoss()
    criterionL1 = torch.nn.L1Loss()

    optimizerD = torch.optim.Adam(netD.parameters(), lr=cfg['optimizer']['lr'],
                                  betas=(cfg['optimizer']['beta1'], 0.999))
    optimizerG = torch.optim.Adam(netG.parameters(), lr=cfg['optimizer']['lr'],
                                  betas=(cfg['optimizer']['beta1'], 0.999))

    ###########   GLOBAL VARIABLES   ###########
    input_nc = cfg['input_nc']
    output_nc = cfg['output_nc']
    fineSize = cfg['data']['patch_size']

    real_A = Variable(torch.FloatTensor(cfg['batch_size'], input_nc, fineSize, fineSize), requires_grad=False).to(
        device)
    real_B = Variable(torch.FloatTensor(cfg['batch_size'], output_nc, fineSize, fineSize), requires_grad=False).to(
        device)
    label = Variable(torch.FloatTensor(cfg['batch_size']), requires_grad=False).to(device)

    real_label = 1
    fake_label = 0

    ########### Training   ###########
    netD.train()
    netG.train()
    for epoch in range(1, cfg['max_iters'] + 1):
        for i, image in enumerate(train_loader):
            ########### fDx ###########
            netD.zero_grad()
            if cfg['direction'] == 'OtoB':
                imgA = image[1]
                imgB = image[0]
            else:
                imgA = image[0]
                imgB = image[1]

            # train with real data
            real_A.data.resize_(imgA.size()).copy_(imgA)
            real_B.data.resize_(imgB.size()).copy_(imgB)
            real_AB = torch.cat((real_A, real_B), 1)

            output = netD(real_AB)
            label.data.resize_(output.size())
            label.data.fill_(real_label)
            errD_real = criterion(output, label)
            errD_real.backward()

            # train with fake
            fake_B = netG(real_A)
            label.data.fill_(fake_label)

            fake_AB = torch.cat((real_A, fake_B), 1)
            output = netD(fake_AB.detach())
            errD_fake = criterion(output, label)
            errD_fake.backward()

            errD = (errD_fake + errD_real) / 2
            optimizerD.step()

            ########### fGx ###########
            netG.zero_grad()
            label.data.fill_(real_label)
            output = netD(fake_AB)
            errGAN = criterion(output, label)
            errL1 = criterionL1(fake_B, real_B)
            errG = errGAN + cfg['lamb'] * errL1

            errG.backward()
            optimizerG.step()

            ########### Logging ##########
            if i % 50 == 0:
                logger.info('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f Loss_L1: %.4f'
                            % (epoch, cfg['max_iters'], i, len(train_loader),
                               errD.item(), errGAN.item(), errL1.item()))

            if cfg['vis']['use'] and (i % 50 == 0):
                fake_B = netG(real_A)
                vis.images(real_A.data.cpu().numpy(), win='real_A')
                vis.images(fake_B.detach().cpu().numpy(), win='fake_B')
                vis.images(real_B.data.cpu().numpy(), win='real_B')
                vis.plot('error_d', errD.item())
                vis.plot('error_g', errGAN.item())
                vis.plot('error_L1', errL1.item())

        if epoch % 20 == 0:
            save_image(
                name='train',
                img_lists=[real_A.data.cpu(), fake_B.data.cpu(), real_B.data.cpu()],
                path='%s/fake_samples_epoch_%03d.png' % (cfg['checkpoint_dir'], epoch),
                step=epoch,
                batch_size=cfg['batch_size']
            )
            save_checkpoints(model=netG,
                             step=epoch,
                             optim=optimizerG,
                             model_dir=cfg['checkpoint_dir'],
                             name='{}_step_{}'.format(cfg['netg'] + cfg['data']['dataset'], epoch))
            save_checkpoints(model=netD,
                             step=epoch,
                             optim=optimizerD,
                             model_dir=cfg['checkpoint_dir'],
                             name='{}_step_{}'.format(cfg['netd'] + cfg['data']['dataset'], epoch))
Example #16
                          reduction='sum')  # ignore <pad> token
        avg_loss += loss.item()
        tot_sz += src.size(1)  # L x N x F
    print("Perplexity: {:12.6}".format(avg_loss / tot_sz))
    summary_writer.add_scalar("perplexity", avg_loss / tot_sz, niter)
    return avg_loss


if __name__ == "__main__":
    conf = Config(args.config_file)
    train_data, valid_data, test_data = create_dateset_from_config(conf)
    model, vocab = create_model_from_config(conf)
    model, optimizer = create_optimizer_from_config(conf, model)
    writer = SummaryWriter(args.summary_folder)

    if args.mode == "train":
        print("Start training mode....")
        niter = 0
        for nepoch in range(conf["train_config"]["nepochs"]):
            # update iteration number
            niter = train(model, optimizer, niter, train_data, valid_data,
                          test_data, writer)
            if (niter + 1) % conf["train_config"]["nsave"] == 0:
                save_checkpoints(model, "model.pt", optimizer, niter)
            # do eval at the end of each epoch
            ppl = predict(model, valid_data, nepoch)
    elif args.mode == "predict":
        ppl = predict(model, test_data, 0)
    else:
        print("Invalid mode argument: {}, exiting...".format(args.mode))