Beispiel #1
0
def main(config):
    """Synthesize a picture from ``config.your_pic`` and ``config.celeb_pic``.

    Loads the ResNet encoder and U-Net generator weights saved for epoch
    ``config.inf_epoch`` in ``config.checkpoints``, runs them on the two input
    pictures and shows both inputs next to the synthesized result.

    Args:
        config: Object providing ``your_pic``, ``celeb_pic``, ``image_size``,
            ``device``, ``checkpoints`` and ``inf_epoch``; it is also passed
            to ``load_model``.

    Returns nothing; on a bad image path or a missing checkpoint it prints a
    message and returns without plotting.
    """
    model = load_model(config)
    try:
        your_pic = Image.open(config.your_pic)
        celeb_pic = Image.open(config.celeb_pic)
    except Exception:
        print("Provide correct paths!")
        return

    im1, m1, im2, m2 = DataLoader(
        transform(your_pic, celeb_pic, config.image_size, model,
                  config.device))
    # Tile the single sample into a batch of two.
    im1, m1, im2, m2 = (x.repeat(2, 1, 1, 1) for x in (im1, m1, im2, m2))

    # Initialize
    resnet = ResNet18(requires_grad=True, pretrained=True).to(config.device)
    generator = GeneratorUNet().to(config.device)

    try:
        for net, name in ((resnet, 'resnet'), (generator, 'generator')):
            path = os.path.join(
                config.checkpoints,
                'epoch_%d_%s.pth' % (config.inf_epoch, name))
            net.load_state_dict(torch.load(path, map_location=config.device))
    except OSError:
        print('Check if your pretrained weight is in the right place.')
        # Without the trained weights the synthesized picture is meaningless.
        return

    with torch.no_grad():
        resnet.eval()
        generator.eval()
        z = resnet(im2 * m2)
        fake_im = generator(im1, im2, z)

    images = [im1[0], im2[0], fake_im[0]]
    titles = ['Your picture', 'Celebrity picture', 'Synthesized picture']

    _, axes = plt.subplots(1, len(titles))
    for ax, img, title in zip(axes, images, titles):
        # Rescale with (x + 1) / 2 (presumably from [-1, 1]) and clamp to the
        # [0, 1] range imshow expects for float RGB data; CHW -> HWC.
        rgb = ((img.detach() + 1) / 2).clamp(0, 1)
        ax.imshow(rgb.cpu().numpy().transpose(1, 2, 0))
        ax.axis('off')
        ax.set_title(title)

    plt.show()
Beispiel #2
0
def inference(config):
    """Synthesize a picture from ``config.your_pic`` and ``config.celeb_pic``.

    The ResNet encoder embeds the masked regions of both pictures (the
    original comments label them skin for yours and hair for the celebrity's).
    The U-Net generator combines your picture with both embeddings. The two
    inputs and the synthesized result are then shown side by side.

    Args:
        config: Object providing ``your_pic``, ``celeb_pic``, ``image_size``,
            ``device`` and ``checkpoints``; it is also passed to
            ``load_model``. The optional ``inf_epoch`` selects the checkpoint
            epoch and defaults to 20.

    Returns nothing; on a bad image path or a missing checkpoint it prints a
    message and returns without plotting.
    """
    hair_model, skin_model = load_model(config)

    try:
        your_pic = Image.open(config.your_pic)
        celeb_pic = Image.open(config.celeb_pic)
    except Exception:
        # Report the problem instead of returning silently.
        print("Provide correct paths!")
        return

    your_pic, your_pic_mask, celeb_pic, celeb_pic_mask = DataLoader(
        transform(your_pic, celeb_pic, config.image_size, hair_model,
                  skin_model, config.device))

    # Initialize
    resnet = ResNet18(requires_grad=True, pretrained=True).to(config.device)
    generator = GeneratorUNet().to(config.device)

    # Checkpoint epoch: configurable, keeping the former hard-coded 20 as default.
    epoch = getattr(config, 'inf_epoch', 20)
    try:
        for net, name in ((resnet, 'resnet'), (generator, 'generator')):
            path = os.path.join(config.checkpoints,
                                'epoch_%d_%s.pth' % (epoch, name))
            # map_location lets GPU-saved checkpoints load on CPU-only hosts.
            net.load_state_dict(torch.load(path, map_location=config.device))
    except OSError:
        print('Check if your pretrained weight is in the right place.')
        # Without the trained weights the synthesized picture is meaningless.
        return

    with torch.no_grad():
        resnet.eval()
        generator.eval()
        z1 = resnet(your_pic * your_pic_mask)  # skin
        z2 = resnet(celeb_pic * celeb_pic_mask)  # hair
        fake_im = generator(your_pic, z1, z2)  # z1 is skin, z2 is hair

    images = [your_pic[0], celeb_pic[0], fake_im[0]]
    titles = ['Your picture', 'Celebrity picture', 'Synthesized picture']

    _, axes = plt.subplots(1, len(titles))
    for ax, img, title in zip(axes, images, titles):
        # Rescale with (x + 1) / 2 and clamp to imshow's float range; CHW -> HWC.
        rgb = ((img.detach() + 1) / 2).clamp(0, 1)
        ax.imshow(rgb.cpu().numpy().transpose(1, 2, 0))
        ax.axis('off')
        ax.set_title(title)

    plt.show()
Beispiel #3
0
def get_network(model, channel, num_classes, im_size=(32, 32)):
    """Build the network named ``model`` and move it to the available device.

    Args:
        model: Architecture name. It is one of ``'MLP'``, ``'LeNet'``,
            ``'AlexNet'``, ``'VGG11'``, ``'VGG11BN'``, ``'ResNet18'``,
            ``'ResNet18BN_AP'``, ``'ConvNet'``, or a ``ConvNet`` ablation
            variant. The variants are:
            ``ConvNetD1..D4`` (depth), ``ConvNetW32..W256`` (width),
            ``ConvNetAS/AR/AL`` (activation),
            ``ConvNetNN/BN/LN/IN/GN`` (normalization) and
            ``ConvNetNP/MP/AP`` (pooling).
        channel: Number of input channels.
        num_classes: Number of output classes.
        im_size: Input spatial size, forwarded to every ``ConvNet`` variant.

    Returns:
        The network, moved to ``'cuda'`` when at least one GPU is visible, or
        to ``'cpu'`` otherwise. With more than one GPU it is wrapped in
        ``nn.DataParallel``.

    Exits the process with ``'DC error: unknown model'`` for an unknown name.
    """
    import sys

    # Reseed from the wall clock so each call gets a different initialization.
    torch.random.manual_seed(int(time.time() * 1000) % 100000)
    net_width, net_depth, net_act, net_norm, net_pooling = \
        get_default_convnet_setting()

    # Each ConvNet variant overrides at most one of the default settings.
    convnet_overrides = {
        'ConvNet': {},
        'ConvNetD1': {'net_depth': 1},
        'ConvNetD2': {'net_depth': 2},
        'ConvNetD3': {'net_depth': 3},
        'ConvNetD4': {'net_depth': 4},
        'ConvNetW32': {'net_width': 32},
        'ConvNetW64': {'net_width': 64},
        'ConvNetW128': {'net_width': 128},
        'ConvNetW256': {'net_width': 256},
        'ConvNetAS': {'net_act': 'sigmoid'},
        'ConvNetAR': {'net_act': 'relu'},
        'ConvNetAL': {'net_act': 'leakyrelu'},
        'ConvNetNN': {'net_norm': 'none'},
        'ConvNetBN': {'net_norm': 'batchnorm'},
        'ConvNetLN': {'net_norm': 'layernorm'},
        'ConvNetIN': {'net_norm': 'instancenorm'},
        'ConvNetGN': {'net_norm': 'groupnorm'},
        'ConvNetNP': {'net_pooling': 'none'},
        'ConvNetMP': {'net_pooling': 'maxpooling'},
        'ConvNetAP': {'net_pooling': 'avgpooling'},
    }

    # Architectures that only take channel/num_classes. The lambdas defer each
    # class-name lookup until that architecture is actually requested.
    simple_builders = {
        'MLP': lambda: MLP(channel=channel, num_classes=num_classes),
        'LeNet': lambda: LeNet(channel=channel, num_classes=num_classes),
        'AlexNet': lambda: AlexNet(channel=channel, num_classes=num_classes),
        'VGG11': lambda: VGG11(channel=channel, num_classes=num_classes),
        'VGG11BN': lambda: VGG11BN(channel=channel, num_classes=num_classes),
        'ResNet18': lambda: ResNet18(channel=channel,
                                     num_classes=num_classes),
        'ResNet18BN_AP': lambda: ResNet18BN_AP(channel=channel,
                                               num_classes=num_classes),
    }

    if model in convnet_overrides:
        settings = {
            'net_width': net_width,
            'net_depth': net_depth,
            'net_act': net_act,
            'net_norm': net_norm,
            'net_pooling': net_pooling,
        }
        settings.update(convnet_overrides[model])
        # im_size is passed to every variant (it used to reach only 'ConvNet'),
        # so the variants also work for non-default input sizes.
        net = ConvNet(channel=channel,
                      num_classes=num_classes,
                      im_size=im_size,
                      **settings)
    elif model in simple_builders:
        net = simple_builders[model]()
    else:
        sys.exit('DC error: unknown model')

    gpu_num = torch.cuda.device_count()
    if gpu_num > 0:
        device = 'cuda'
        if gpu_num > 1:
            net = nn.DataParallel(net)
    else:
        device = 'cpu'
    net = net.to(device)

    return net
Beispiel #4
0
                             num_workers=config['num_workers'])
    valloader = DataLoader(valset,
                           config['test_size'],
                           shuffle=False,
                           num_workers=config['num_workers'])

    print(f"Train data: {len(trainset)} | Val data: {len(valset)}")
    print(f"Number of batch: {len(trainloader)}/epoch")

    ########################################
    # Model
    ########################################
    num_class = len(class2idx)
    input_size = config['new_size']
    if config['model'] == 'ResNet18':
        model = ResNet18(input_size, num_class, config['pretrained'],
                         config['freeze'])
    elif config['model'] == 'HaHaNet':
        model = HaHaNet(input_size, num_class, config['pretrained'],
                        config['freeze'])
    else:
        raise f"Model {config['model']} is not support."

    # Drop model into GPU
    model.cuda()

    ########################################
    # Criterion and Optimizer
    ########################################
    def criterion(predicts, targets):
        """ predicts: (N, C), targets: (N) """
        loss = nn.CrossEntropyLoss(predicts, targets)
Beispiel #5
0
def _checkpoint_path(config, epoch, name):
    """Return the path of the ``name`` checkpoint saved for ``epoch``."""
    return os.path.join(config.checkpoints, 'epoch_%d_%s.pth' % (epoch, name))


def _patch_targets(batch_size, patch, device):
    """Return ``(valid, fake)`` PatchGAN targets: all ones and all zeros."""
    shape = (batch_size, *patch)
    return (torch.ones(shape, device=device),
            torch.zeros(shape, device=device))


def _generator_forward(config, nets, model, vgg, criterion_GAN, epoch, batch,
                       valid):
    """Synthesize ``batch`` and compute the generator losses.

    Args:
        config: Training configuration (loss weights, ``gan_epochs``, device).
        nets: ``(resnet, generator, discriminator)``.
        model: Network whose per-pixel argmax over dim 1 gives a mask for the
            synthesized image.
        vgg: Feature network passed to the style/content losses.
        criterion_GAN: Adversarial loss function.
        epoch: Current epoch; the adversarial phase is ``epoch < gan_epochs``.
        batch: ``(im1, m1, im2, m2)`` as yielded by the data loaders.
        valid: PatchGAN "real" target for the adversarial loss.

    Returns:
        ``(fake_im, gan_loss, hair_loss, face_loss, loss)``, where ``loss`` is
        the sum of the three weighted terms.
    """
    resnet, generator, discriminator = nets
    im1, m1, im2, m2 = batch

    z = resnet(im2 * m2)
    if epoch < config.gan_epochs:
        fake_im = generator(im1 * (1 - m1), im2 * m2, z)
        pred_fake = discriminator(fake_im, im2)
        gan_loss = config.lambda_gan * criterion_GAN(pred_fake, valid)
    else:
        fake_im = generator(im1, im2, z)
        gan_loss = torch.Tensor([0]).to(config.device)

    # Mask of fake_im from model's class argmax, repeated to 3 channels.
    fake_m2 = torch.argmax(model(fake_im), 1).unsqueeze(1).type(
        torch.uint8).repeat(1, 3, 1, 1).to(config.device)
    # Use the predicted mask only when its area is within 0.5x-1.5x of m1's;
    # otherwise fall back to m1.
    m1_area = torch.sum(m1)
    if 0.5 * m1_area <= torch.sum(fake_m2) <= 1.5 * m1_area:
        mask = fake_m2
    else:
        mask = m1
    style = calc_style_loss(fake_im * mask, im2 * m2, vgg)
    content = calc_content_loss(fake_im * mask, im2 * m2, vgg)
    hair_loss = config.lambda_hair * (config.lambda_style * style + content)
    face_loss = config.lambda_face * calc_content_loss(fake_im, im1, vgg)

    loss = gan_loss + hair_loss + face_loss
    return fake_im, gan_loss, hair_loss, face_loss, loss


def _log_validation_sample(config, nets, model, vgg, criterion_GAN, patch,
                           val_loader, epoch, i):
    """Print losses and save a sample image for the first validation batch."""
    batch = next(iter(val_loader), None)
    if batch is None:
        return
    # Evaluate in eval mode so validation passes do not change training-time
    # state (e.g. BatchNorm running statistics); always restore train mode.
    for net in nets:
        net.eval()
    try:
        with torch.no_grad():
            im1, m1, im2, m2 = batch
            valid, _ = _patch_targets(im1.size(0), patch, config.device)
            fake_im, gan_loss, hair_loss, face_loss, loss = _generator_forward(
                config, nets, model, vgg, criterion_GAN, epoch, batch, valid)
    finally:
        for net in nets:
            net.train()

    msg = "Validation || Gan loss: %.6f, hair loss: %.6f, face loss: %.6f, loss: %.6f\n" % \
        (gan_loss.item(), hair_loss.item(), face_loss.item(), loss.item())
    sys.stdout.write(msg)
    fname = os.path.join(config.save_path,
                         "Validation_Epoch:%d_Batch:%d.png" % (epoch, i))
    sample_images([im1[0], im2[0], fake_im[0]],
                  ["img1", "img2", "img1+img2"], fname)


def main(config):
    """Train the ResNet encoder, U-Net generator and PatchGAN discriminator.

    For epochs ``< config.gan_epochs`` the adversarial loss and the
    discriminator are active; later epochs train the generator and encoder
    only on the hair/face losses. With ``config.epoch != 0`` the models are
    resumed from the checkpoints of epoch ``config.epoch - 1``. Every
    ``config.sample_interval`` batches, losses are printed and sample images
    saved (training batch plus the first validation batch). Every
    ``config.checkpoint_interval`` epochs, weights are saved.

    Args:
        config: Training configuration; it is also passed to ``load_model``
            and ``get_loaders``.
    """
    model = load_model(config)
    train_loader, val_loader = get_loaders(model, config)

    # Make dirs (exist_ok already tolerates existing directories).
    os.makedirs(config.checkpoints, exist_ok=True)
    os.makedirs(config.save_path, exist_ok=True)

    # Loss Functions
    criterion_GAN = mse_loss

    # Output shape of the image discriminator (PatchGAN): image_size / 16.
    patch = (1, config.image_size // 2**4, config.image_size // 2**4)

    # Initialize
    vgg = Vgg16().to(config.device)
    resnet = ResNet18(requires_grad=True, pretrained=True).to(config.device)
    generator = GeneratorUNet().to(config.device)
    discriminator = Discriminator().to(config.device)
    nets = (resnet, generator, discriminator)

    if config.epoch != 0:
        # Load pretrained models
        prev = config.epoch - 1
        to_load = [(resnet, 'resnet'), (generator, 'generator')]
        # The discriminator is saved (and used) only while epoch < gan_epochs,
        # so a resume past that point has no discriminator checkpoint.
        if prev < config.gan_epochs:
            to_load.append((discriminator, 'discriminator'))
        for net, name in to_load:
            net.load_state_dict(
                torch.load(_checkpoint_path(config, prev, name),
                           map_location=config.device))
    else:
        # Initialize weights (resnet keeps its pretrained weights).
        generator.apply(weights_init_normal)
        discriminator.apply(weights_init_normal)

    # Optimizers
    betas = (config.b1, config.b2)
    optimizer_resnet = torch.optim.Adam(resnet.parameters(), lr=config.lr,
                                        betas=betas)
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=config.lr,
                                   betas=betas)
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=config.lr,
                                   betas=betas)

    # ----------
    #  Training
    # ----------
    for net in nets:
        net.train()
    for epoch in range(config.epoch, config.n_epochs):
        for i, batch in enumerate(train_loader):
            im1, m1, im2, m2 = batch
            assert im1.size(0) == im2.size(0)
            # "fake" must be all zeros; an all-ones target would teach the
            # discriminator to accept generated images as real.
            valid, fake = _patch_targets(im1.size(0), patch, config.device)

            # Train generators
            optimizer_resnet.zero_grad()
            optimizer_G.zero_grad()
            fake_im, gan_loss, hair_loss, face_loss, loss = _generator_forward(
                config, nets, model, vgg, criterion_GAN, epoch, batch, valid)
            loss.backward()
            optimizer_resnet.step()
            optimizer_G.step()

            # Train discriminator
            if epoch < config.gan_epochs:
                # Also clears grads the generator step left in discriminator.
                optimizer_D.zero_grad()
                pred_real = discriminator(im1 * (1 - m1) + im2 * m2, im2)
                loss_real = criterion_GAN(pred_real, valid)
                pred_fake = discriminator(fake_im.detach(), im2)
                loss_fake = criterion_GAN(pred_fake, fake)
                loss_D = 0.5 * (loss_real + loss_fake)
                loss_D.backward()
                optimizer_D.step()

            if i % config.sample_interval == 0:
                msg = "Train || Gan loss: %.6f, hair loss: %.6f, face loss: %.6f, loss: %.6f\n" % \
                    (gan_loss.item(), hair_loss.item(), face_loss.item(), loss.item())
                sys.stdout.write("Epoch: %d || Batch: %d\n" % (epoch, i))
                sys.stdout.write(msg)
                fname = os.path.join(
                    config.save_path,
                    "Train_Epoch:%d_Batch:%d.png" % (epoch, i))
                sample_images([im1[0], im2[0], fake_im[0]],
                              ["img1", "img2", "img1+img2"], fname)
                _log_validation_sample(config, nets, model, vgg,
                                       criterion_GAN, patch, val_loader,
                                       epoch, i)

        if epoch % config.checkpoint_interval == 0:
            models = [resnet, generator]
            names = ['resnet', 'generator']
            if epoch < config.gan_epochs:
                models.append(discriminator)
                names.append('discriminator')
            save_weights(models,
                         [_checkpoint_path(config, epoch, s) for s in names])
Beispiel #6
0
    torch.cuda.manual_seed_all(args.seed)

    #======================================================================
    # データセットを読み込み or 生成
    # データの前処理
    #======================================================================
    df_test = DogsVSCatsDataset( args, args.dataset_dir, "test", enable_da = False )
    dloader_test = DogsVSCatsDataLoader(df_test, batch_size=args.batch_size, shuffle=False, n_workers = args.n_workers )

    #======================================================================
    # モデルの構造を定義する。
    #======================================================================
    if( args.network_type[0] == "my_resnet18" ):
        model1 = MyResNet18( n_in_channels = 3, n_fmaps = 64, n_classes = 2 ).to(device)
    elif( args.network_type[0] == "resnet18" ):
        model1 = ResNet18( n_classes = 2, pretrained = False, train_only_fc = False ).to(device)
    else:
        model1 = ResNet50( n_classes = 2, pretrained = False, train_only_fc = False ).to(device)
 
    if( args.network_type[1] == "my_resnet18" ):
        model2 = MyResNet18( n_in_channels = 3, n_fmaps = 64, n_classes = 2 ).to(device)
    elif( args.network_type[1] == "resnet18" ):
        model2 = ResNet18( n_classes = 2, pretrained = False, train_only_fc = False ).to(device)
    else:
        model2 = ResNet50( n_classes = 2, pretrained = False, train_only_fc = False ).to(device)
        
    if( args.network_type[1] == "my_resnet18" ):
        model3 = MyResNet18( n_in_channels = 3, n_fmaps = 64, n_classes = 2 ).to(device)
    elif( args.network_type[1] == "resnet18" ):
        model3 = ResNet18( n_classes = 2, pretrained = False, train_only_fc = True ).to(device)
    else:
        board_train = SummaryWriter(
            log_dir=os.path.join(args.tensorboard_dir, args.exper_name +
                                 "_kfold{}".format(fold_id)))
        board_test = SummaryWriter(log_dir=os.path.join(
            args.tensorboard_dir, args.exper_name +
            "_kfold{}".format(fold_id) + "_test"))

        #======================================================================
        # モデルの構造を定義する。
        #======================================================================
        if (args.network_type == "my_resnet18"):
            model = MyResNet18(n_in_channels=3,
                               n_fmaps=args.n_fmaps,
                               n_classes=2).to(device)
        elif (args.network_type == "resnet18"):
            model = ResNet18(n_classes=2,
                             pretrained=args.pretrained).to(device)
        else:
            model = ResNet50(n_classes=2,
                             pretrained=args.pretrained).to(device)

        # モデルを読み込む
        if not args.load_checkpoints_path == '' and os.path.exists(
                args.load_checkpoints_path):
            load_checkpoint(model, device, args.load_checkpoints_path)

        #======================================================================
        # optimizer の設定
        #======================================================================
        optimizer = optim.Adam(params=model.parameters(),
                               lr=args.lr,
                               betas=(args.beta1, args.beta2))
                                args.dataset_dir,
                                "test",
                                enable_da=False)
    dloader_test = DogsVSCatsDataLoader(df_test,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        n_workers=args.n_workers)

    #======================================================================
    # モデルの構造を定義する。
    #======================================================================
    if (args.network_type[0] == "my_resnet18"):
        resnet = MyResNet18(n_in_channels=3, n_fmaps=64,
                            n_classes=2).to(device)
    elif (args.network_type[0] == "resnet18"):
        resnet = ResNet18(n_classes=2, pretrained=False,
                          train_only_fc=False).to(device)
    else:
        resnet = ResNet50(n_classes=2, pretrained=False,
                          train_only_fc=False).to(device)

    resnet_classifier = ImageClassifierPyTorch(device,
                                               resnet,
                                               debug=args.debug)
    resnet_classifier.load_check_point(args.load_checkpoints_path[0])

    knn_classifier = ImageClassifierSklearn(
        KNeighborsClassifier(n_neighbors=2, p=2, metric='minkowski',
                             n_jobs=-1),
        debug=args.debug,
    )
Beispiel #9
0
                                args.dataset_dir,
                                "test",
                                enable_da=False)
    dloader_test = DogsVSCatsDataLoader(df_test,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        n_workers=args.n_workers)

    #======================================================================
    # モデルの構造を定義する。
    #======================================================================
    if (args.network_type == "my_resnet18"):
        model = MyResNet18(n_in_channels=3, n_fmaps=args.n_fmaps,
                           n_classes=2).to(device)
    elif (args.network_type == "resnet18"):
        model = ResNet18(n_classes=2, pretrained=False).to(device)
    else:
        model = ResNet50(n_classes=2, pretrained=False).to(device)

    if (args.debug):
        print("model :\n", model)

    # モデルを読み込む
    if not args.load_checkpoints_path == '' and os.path.exists(
            args.load_checkpoints_path):
        load_checkpoint(model, device, args.load_checkpoints_path)
        print("load check points")

    #======================================================================
    # モデルの推論処理
    #======================================================================
Beispiel #10
0
                               shuffle=True)

    dloader_test = DataLoader(dataset=ds_test,
                              batch_size=args.batch_size_test,
                              shuffle=False)

    if (args.debug):
        print("ds_train :\n", ds_train)
        print("ds_test :\n", ds_test)

    #======================================================================
    # モデルの構造を定義する。
    #======================================================================
    if (args.dataset == "mnist"):
        model = ResNet18(n_in_channels=1,
                         n_fmaps=args.n_fmaps,
                         n_classes=args.n_classes).to(device)
    else:
        model = ResNet18(n_in_channels=3,
                         n_fmaps=args.n_fmaps,
                         n_classes=args.n_classes).to(device)

    if (args.debug):
        print("model :\n", model)

    # モデルを読み込む
    if not args.load_checkpoints_dir == '' and os.path.exists(
            args.load_checkpoints_dir):
        init_step = load_checkpoint(
            model, device,
            os.path.join(args.load_checkpoints_dir, "model_final.pth"))