Example 1
def build_tools(model):
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=cfg.WARMUP_LR,
                                weight_decay=cfg.WEIGHT_DECAY,
                                momentum=cfg.MOMENTUM)

    schedule_helper = CosineLRScheduler(lr_warmup_init=cfg.WARMUP_LR,
                                        base_lr=cfg.BASE_LR,
                                        lr_warmup_step=cfg.STEPS_PER_EPOCH,
                                        total_steps=cfg.TOTAL_STEPS)

    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda step: schedule_helper.get_lr_coeff(step))

    criterion = DetectionLoss(cfg.ALPHA, cfg.GAMMA, cfg.DELTA,
                              cfg.BOX_LOSS_WEIGHT, cfg.NUM_CLASSES)

    ema = ExponentialMovingAverage(model, cfg.MOVING_AVERAGE_DECAY)

    return optimizer, scheduler, criterion, ema
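Example 1 only constructs the EMA helper; a minimal sketch of a wrapper with this (model, decay) constructor might look like the following. The class and method names here are illustrative assumptions, not the project's actual implementation: shadow copies of the trainable parameters are updated after every optimizer step and can be copied into a separate evaluation model.

import torch

class SimpleModelEMA:
    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        # one detached copy per trainable parameter
        self.shadow = {n: p.detach().clone()
                       for n, p in model.named_parameters() if p.requires_grad}

    @torch.no_grad()
    def update(self):
        # shadow <- decay * shadow + (1 - decay) * param, called after each optimizer step
        for n, p in self.model.named_parameters():
            if n in self.shadow:
                self.shadow[n].mul_(self.decay).add_(p, alpha=1.0 - self.decay)

    @torch.no_grad()
    def copy_to(self, eval_model):
        # load the averaged weights into a structurally identical evaluation model
        for n, p in eval_model.named_parameters():
            if n in self.shadow:
                p.copy_(self.shadow[n])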
Example 2
# Combined model (trains the generator)
g_model.trainable = True
d_model.trainable = False
x_fake_score = d_model(g_model(z_in))

g_train_model = Model(z_in, x_fake_score)
g_train_model.add_loss(K.mean(-K.log(x_fake_score + 1e-9)))
g_train_model.compile(optimizer=Adam(2e-4, 0.5))

# Check model structures (d_train_model is built earlier, when setting up discriminator training)
d_train_model.summary()
g_train_model.summary()

# EMA
if EMA:
    EMAer_g_train = ExponentialMovingAverage(g_train_model,
                                             0.999)  # create after the model is compiled
    EMAer_g_train.inject()  # must be called after the model is compiled

# Training loop
for i in range(total_iter):
    for j in range(1):  # one discriminator update per iteration
        next_batch = next(img_generator)
        z_sample = np.random.randn(len(next_batch), z_dim)
        d_loss = d_train_model.train_on_batch([next_batch, z_sample], None)
    for j in range(2):  # two generator updates per iteration
        z_sample = np.random.randn(batch_size, z_dim)
        g_loss = g_train_model.train_on_batch(z_sample, None)
        if EMA:
            EMAer_g_train.ema_on_batch()
    if i % 10 == 0:
        print('iter: %s, d_loss: %s, g_loss: %s' % (i, d_loss, g_loss))
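Example 2 calls inject() once after compile and ema_on_batch() after every generator step, then (elsewhere) applies the averaged weights for generation. A hypothetical minimal version of the per-batch update path is sketched below; the original class's inject() presumably hooks the update into the compiled training function, which this sketch omits. All names here are assumptions.

from keras import backend as K

class SimpleKerasEMA:
    def __init__(self, model, momentum=0.999):
        self.model = model
        self.momentum = momentum
        # shadow copies kept as plain numpy arrays
        self.shadow = [K.get_value(w) for w in model.trainable_weights]

    def ema_on_batch(self):
        # shadow <- momentum * shadow + (1 - momentum) * weight
        for i, w in enumerate(self.model.trainable_weights):
            self.shadow[i] = (self.momentum * self.shadow[i]
                              + (1.0 - self.momentum) * K.get_value(w))

    def apply_ema_weights(self):
        # back up the live weights and load the averaged ones for inference
        self.backup = [K.get_value(w) for w in self.model.trainable_weights]
        for w, s in zip(self.model.trainable_weights, self.shadow):
            K.set_value(w, s)

    def reset_old_weights(self):
        # restore the live (non-averaged) weights to continue training
        for w, b in zip(self.model.trainable_weights, self.backup):
            K.set_value(w, b)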
Example 3
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, required=True)
    parser.add_argument('--batch_size', type=int, default=2048)
    parser.add_argument('--epochs', type=int, default=360)
    parser.add_argument('--warmup_epochs', type=int, default=5)
    parser.add_argument('--base_lr', type=float, default=2.6)
    parser.add_argument('--init_lr', type=float, default=0.0)
    parser.add_argument('--initial_epoch', type=int, default=0)
    parser.add_argument('--imagenet_path', type=str, required=True)
    parser.add_argument('--image_size', type=int, default=224)
    parser.add_argument('--use_cache', action='store_true')
    parser.add_argument('--checkpoint_path', type=str, required=True)
    parser.add_argument('--resume', type=str, default='')
    parser.add_argument('--save_every_epoch', action='store_true')
    parser.add_argument('--use_ema', action='store_true')
    args = parser.parse_args()
    print(args)

    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as e:
            print(e)
            exit()

    steps_per_epoch = 1281167 // args.batch_size  # 1,281,167 ImageNet-1k training images
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        if args.resume:
            model = tf.keras.models.load_model(args.resume)
        else:
            net = import_module(f'models.{args.model_name}')
            model = net.get_model()
    model.summary()
    train_dataset = imagenet.get_train_dataset(
        args.imagenet_path,
        args.batch_size,
        imagenet.NormalizeMethod.TF,
        use_color_jitter=True,
        use_one_hot=True,
        use_cache=args.use_cache,
        image_size=args.image_size).repeat().prefetch(10)
    val_dataset = imagenet.get_val_dataset(args.imagenet_path,
                                           args.batch_size,
                                           imagenet.NormalizeMethod.TF,
                                           use_one_hot=True,
                                           use_cache=args.use_cache,
                                           image_size=args.image_size)

    now_time = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    path = os.path.join(args.checkpoint_path, f'{args.model_name}-{now_time}')
    os.makedirs(path, exist_ok=True)
    if args.save_every_epoch:
        saved_model_file = 'model.{epoch:03d}.h5'  # epoch number is filled in by ModelCheckpoint
    else:
        saved_model_file = 'model.best.h5'
    filepath = os.path.join(path, saved_model_file)
    callbacks = [
        LearningRateScheduler(epochs=args.epochs,
                              warmup_epochs=args.warmup_epochs,
                              steps_per_epoch=steps_per_epoch,
                              base_lr=args.base_lr,
                              init_lr=args.init_lr,
                              initial_epoch=args.initial_epoch),
        ModelCheckpoint(filepath,
                        monitor='val_categorical_accuracy',
                        verbose=0,
                        save_best_only=(not args.save_every_epoch),
                        save_weights_only=False,
                        mode='auto',
                        period=1),
        TensorBoard(f'{path}/logs')
    ]
    if args.use_ema:
        callbacks.append(ExponentialMovingAverage())

    model.fit(train_dataset,
              epochs=args.epochs,
              steps_per_epoch=steps_per_epoch,
              shuffle=False,
              validation_data=val_dataset,
              callbacks=callbacks,
              initial_epoch=args.initial_epoch)

    model.save(f'{path}/model.h5')
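In example 3 the EMA is appended to the callback list, so its interface is only implied. A hypothetical minimal Keras callback with that shape is sketched below (assumed from usage, not the repository's actual class): it creates shadow variables at the start of training, folds the current weights in after every batch, and exposes a method to swap the averages in before exporting.

import tensorflow as tf

class EMACallback(tf.keras.callbacks.Callback):
    def __init__(self, decay=0.9999):
        super().__init__()
        self.decay = decay

    def on_train_begin(self, logs=None):
        # one non-trainable shadow variable per model weight
        self.shadow = [tf.Variable(w, trainable=False) for w in self.model.weights]

    def on_train_batch_end(self, batch, logs=None):
        # shadow <- decay * shadow + (1 - decay) * weight
        for s, w in zip(self.shadow, self.model.weights):
            s.assign(self.decay * s + (1.0 - self.decay) * w)

    def swap_in_averaged_weights(self):
        # call before saving a model that should carry the EMA weights
        for s, w in zip(self.shadow, self.model.weights):
            w.assign(s)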
Example 4
def train(num_gpus, rank, group_name, output_directory, epochs, learning_rate,
          iters_per_checkpoint, iters_per_eval, batch_size, seed, checkpoint_path, log_dir, ema_decay=0.9999):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    #=====END:   ADDED FOR DISTRIBUTED======

    if train_data_config["no_chunks"]:
        criterion = MaskedCrossEntropyLoss()
    else:
        criterion = CrossEntropyLoss()
    model = WaveNet(**wavenet_config).cuda()
    ema = ExponentialMovingAverage(ema_decay)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)

    #=====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
    #=====END:   ADDED FOR DISTRIBUTED======

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = StepLR(optimizer, step_size=200000, gamma=0.5)

    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, scheduler, iteration, ema = load_checkpoint(checkpoint_path, model,
                                                                      optimizer, scheduler, ema)
        iteration += 1  # next iteration is iteration + 1

    trainset = Mel2SampOnehot(audio_config=audio_config, verbose=True, **train_data_config)
    validset = Mel2SampOnehot(audio_config=audio_config, verbose=False, **valid_data_config)
    # =====START: ADDED FOR DISTRIBUTED======
    train_sampler = DistributedSampler(trainset) if num_gpus > 1 else None
    valid_sampler = DistributedSampler(validset) if num_gpus > 1 else None
    # =====END:   ADDED FOR DISTRIBUTED======
    print(train_data_config)
    if train_data_config["no_chunks"]:
        collate_fn = utils.collate_fn
    else:
        collate_fn = torch.utils.data.dataloader.default_collate
    train_loader = DataLoader(trainset, num_workers=1, shuffle=False,
                              collate_fn=collate_fn,
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=True,
                              drop_last=True)
    valid_loader = DataLoader(validset, num_workers=1, shuffle=False,
                              sampler=valid_sampler, batch_size=1, pin_memory=True)
    # Get shared output_directory ready
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)
    
    model.train()
    epoch_offset = max(0, int(iteration / len(train_loader)))
    writer = SummaryWriter(log_dir)
    print("Checkpoints writing to: {}".format(log_dir))
    # ================ MAIN TRAINING LOOP ===================
    for epoch in range(epoch_offset, epochs):
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            if low_memory:
                torch.cuda.empty_cache()
            scheduler.step()
            model.zero_grad()

            if train_data_config["no_chunks"]:
                x, y, seq_lens = batch
                seq_lens = to_gpu(seq_lens)
            else:
                x, y = batch
            x = to_gpu(x).float()
            y = to_gpu(y)
            x = (x, y)  # auto-regressive takes outputs as inputs
            y_pred = model(x)
            if train_data_config["no_chunks"]:
                loss = criterion(y_pred, y, seq_lens)
            else:
                loss = criterion(y_pred, y)
            if num_gpus > 1:
                reduced_loss = reduce_tensor(loss.data, num_gpus)[0]
            else:
                reduced_loss = loss.data[0]
            loss.backward()
            optimizer.step()

            for name, param in model.named_parameters():
                if name in ema.shadow:
                    ema.update(name, param.data)

            print("{}:\t{:.9f}".format(iteration, reduced_loss))
            if rank == 0:
                writer.add_scalar('loss', reduced_loss, iteration)
            if (iteration % iters_per_checkpoint == 0 and iteration):
                if rank == 0:
                    checkpoint_path = "{}/wavenet_{}".format(
                        output_directory, iteration)
                    save_checkpoint(model, optimizer, scheduler, learning_rate, iteration,
                                    checkpoint_path, ema, wavenet_config)
            if (iteration % iters_per_eval == 0 and iteration > 0 and not config["no_validation"]):
                if low_memory:
                    torch.cuda.empty_cache()
                if rank == 0:
                    model_eval = nv_wavenet.NVWaveNet(**(model.export_weights()))
                    for j, valid_batch in enumerate(valid_loader):
                        mel, audio = valid_batch
                        mel = to_gpu(mel).float()
                        cond_input = model.get_cond_input(mel)
                        predicted_audio = model_eval.infer(cond_input, nv_wavenet.Impl.AUTO)
                        predicted_audio = utils.mu_law_decode_numpy(predicted_audio[0, :].cpu().numpy(), 256)
                        writer.add_audio("valid/predicted_audio_{}".format(j),
                                         predicted_audio,
                                         iteration,
                                         22050)
                        audio = utils.mu_law_decode_numpy(audio[0, :].cpu().numpy(), 256)
                        writer.add_audio("valid_true/audio_{}".format(j),
                                         audio,
                                         iteration,
                                         22050)
                        if low_memory:
                            torch.cuda.empty_cache()
            iteration += 1
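Example 4 drives the EMA through register(name, tensor), update(name, tensor) and a shadow dict. A minimal class matching that interface, inferred from those calls (the real class may differ in detail), is:

class DictEMA:
    def __init__(self, decay):
        self.decay = decay
        self.shadow = {}

    def register(self, name, val):
        # store an independent copy as the initial average
        self.shadow[name] = val.clone()

    def update(self, name, val):
        # shadow <- decay * shadow + (1 - decay) * val
        new_average = self.decay * self.shadow[name] + (1.0 - self.decay) * val
        self.shadow[name] = new_average.clone()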
Example 5
def train_fn(args):
    device = torch.device("cuda" if args.use_cuda else "cpu")
    upsample_factor = int(args.frame_shift_ms / 1000 * args.sample_rate)  # frame shift in samples

    model = create_model(args)
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
    for state in optimizer.state.values():
        for key, value in state.items():
            if torch.is_tensor(value):
                state[key] = value.to(device)

    if args.resume is not None:
        print("Resume checkpoint from: {}:".format(args.resume))
        checkpoint = torch.load(args.resume,
                                map_location=lambda storage, loc: storage)
        if torch.cuda.device_count() > 1:
            model.module.load_state_dict(checkpoint['model'])
        else:
            model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint["optimizer"])
        global_step = checkpoint['steps']
    else:
        global_step = 0

    print("receptive field: {0} ({1:.2f}ms)".format(
        model.receptive_field,
        model.receptive_field / args.sample_rate * 1000))

    if args.feature_type == "mcc":
        # mfccs have already been scaled for Ryan
        # scaler = StandardScaler()
        # scaler.mean_ = np.load(os.path.join(args.data_dir, 'mean.npy'))
        # scaler.scale_ = np.load(os.path.join(args.data_dir, 'scale.npy'))
        # feat_transform = transforms.Compose([lambda x: scaler.transform(x)])
        feat_transform = None
    else:
        feat_transform = None

    dataset = FilterbankDataset(
        data_dir=args.data_dir,
        receptive_field=model.receptive_field,
        sample_size=args.sample_size,
        upsample_factor=upsample_factor,
        quantization_channels=args.quantization_channels,
        use_local_condition=args.use_local_condition,
        noise_injecting=args.noise_injecting,
        feat_transform=feat_transform)

    dataloader = DataLoader(dataset,
                            batch_size=args.batch_size,
                            shuffle=True,
                            num_workers=args.num_workers,
                            pin_memory=True)

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    criterion = nn.CrossEntropyLoss()

    ema = ExponentialMovingAverage(args.ema_decay)
    for name, param in model.named_parameters():
        if param.requires_grad:
            ema.register(name, param.data)

    while global_step < args.training_steps:
        for i, data in enumerate(dataloader, 0):
            audio, target, local_condition = data
            target = target.squeeze(-1)
            local_condition = local_condition.transpose(1, 2)
            audio, target, h = audio.to(device), target.to(
                device), local_condition.to(device)

            optimizer.zero_grad()
            output = model(audio[:, :-1, :], h[:, :, 1:])
            loss = criterion(output, target)
            print('step [%3d]: loss: %.3f' % (global_step, loss.item()))

            loss.backward()
            optimizer.step()

            # update moving average
            if ema is not None:
                apply_moving_average(model, ema)

            global_step += 1

            if global_step % args.checkpoint_interval == 0:
                save_checkpoint(device, args, model, optimizer, global_step,
                                args.checkpoint_dir, ema)
                out = output[1, :, :]
                samples = out.argmax(0)
                waveform = mu_law_decode(
                    samples[model.receptive_field:].cpu().numpy(),
                    args.quantization_channels)
                write_wav(
                    waveform, args.sample_rate,
                    os.path.join(args.checkpoint_dir,
                                 "train_eval_{}.wav".format(global_step)))
Example 6
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size',
                        type=int,
                        default=64,
                        help="batch size")
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--resume',
                        '-r',
                        action='store_true',
                        help='resume from checkpoint')
    parser.add_argument('--fold',
                        '-f',
                        type=int,
                        default=0,
                        help='which fold to train on')
    parser.add_argument('--seed', type=int, default=None)
    parser.add_argument('--multi-eval', type=bool, default=False)
    parser.add_argument('--update-freq', type=int, default=1)
    args = parser.parse_args()

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    if args.seed is None:
        args.seed = np.random.randint(100000)

    print("seed: {}".format(args.seed))

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if device.type == 'cuda':
        torch.cuda.manual_seed(args.seed)

    DATASET = 'tiny_imagenet-known-20-split'
    # MODEL = 'custom_classifier_9'
    MODEL = 'hybrid'
    fold_num = args.fold
    batch_size = args.batch_size
    is_train = False
    is_write = False

    start_time = datetime.datetime.now().strftime('%Y-%m-%d_%I-%M-%S-%p')
    runs = 'runs/{}-{}{}-{}'.format(MODEL, DATASET, fold_num, start_time)
    if is_write:
        writer = SummaryWriter(runs)

    closed_trainloader = FlexibleCustomDataloader(
        fold='train',
        batch_size=batch_size,
        dataset='./data/{}{}a.dataset'.format(DATASET, fold_num))
    closed_testloader = FlexibleCustomDataloader(
        fold='test',
        batch_size=batch_size,
        dataset='./data/{}{}a.dataset'.format(DATASET, fold_num))

    open_trainloader = FlexibleCustomDataloader(
        fold='train',
        batch_size=batch_size,
        dataset='./data/{}{}b.dataset'.format(DATASET, fold_num))
    open_testloader = FlexibleCustomDataloader(
        fold='test',
        batch_size=batch_size,
        dataset='./data/{}{}b.dataset'.format(DATASET, fold_num))

    batch_time = RunningAverageMeter(0.97)
    bpd_meter = RunningAverageMeter(0.97)
    logpz_meter = RunningAverageMeter(0.97)
    deltalogp_meter = RunningAverageMeter(0.97)
    firmom_meter = RunningAverageMeter(0.97)
    secmom_meter = RunningAverageMeter(0.97)
    gnorm_meter = RunningAverageMeter(0.97)
    ce_meter = RunningAverageMeter(0.97)

    PATH = '{}/{}{}_hybrid'.format(runs, DATASET, fold_num)
    if is_train:
        encoder = encoder32()
        encoder.to(device)
        encoder.train()

        flow = ResidualFlow(n_classes=20,
                            input_size=(64, 128, 4, 4),
                            n_blocks=[32, 32, 32],
                            intermediate_dim=512,
                            factor_out=False,
                            quadratic=False,
                            init_layer=None,
                            actnorm=True,
                            fc_actnorm=False,
                            dropout=0,
                            fc=False,
                            coeff=0.98,
                            vnorms='2222',
                            n_lipschitz_iters=None,
                            sn_atol=1e-3,
                            sn_rtol=1e-3,
                            n_power_series=None,
                            n_dist='poisson',
                            n_samples=1,
                            kernels='3-1-3',
                            activation_fn='swish',
                            fc_end=True,
                            n_exact_terms=2,
                            preact=True,
                            neumann_grad=True,
                            grad_in_forward=False,
                            first_resblock=True,
                            learn_p=False,
                            classification='hybrid',
                            classification_hdim=256,
                            block_type='resblock')
        flow.to(device)
        flow.train()

        classifier = classifier32()
        classifier.to(device)
        classifier.train()

        ema = ExponentialMovingAverage(flow)

        flow.train()

        criterion = nn.CrossEntropyLoss()
        # optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
        optimizer = optim.Adam(encoder.parameters(), lr=0.0001)
        optimizer_2 = optim.Adam(flow.parameters(), lr=0.0001)
        optimizer_3 = optim.SGD(classifier.parameters(), lr=0.1, momentum=0.9)
        # optimizer_3 = optim.Adam(classifier.parameters(), lr=0.0001)

        # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
        #                                                  milestones=[50, 100, 150, 200, 250, 300, 350, 400, 450],
        #                                                  gamma=0.1)
        beta = 1
        running_loss = 0.0
        running_bpd = 0.0
        running_cls = 0.0
        best_loss = 1000
        tau = 100000
        for epoch in range(600):
            for i, (images, labels) in enumerate(closed_trainloader, 0):
                global_itr = epoch * len(closed_trainloader) + i
                images = Variable(images)
                images = images.cuda()

                labels = Variable(labels)

                # writer.add_graph(net, images)
                outputs = encoder(images)

                bpd, logits, logpz, neg_delta_logp = compute_loss(outputs,
                                                                  flow,
                                                                  beta=beta)
                cls_outputs = classifier(outputs)

                labels = torch.argmax(labels, dim=1)
                cls_loss = criterion(cls_outputs, labels)

                firmom, secmom = estimator_moments(flow)

                bpd_meter.update(bpd.item())
                logpz_meter.update(logpz.item())
                deltalogp_meter.update(neg_delta_logp.item())
                firmom_meter.update(firmom)
                secmom_meter.update(secmom)

                loss = bpd + cls_loss
                #
                # loss.backward()
                #
                # labels = torch.argmax(labels, dim=1)
                #
                # # writer.add_embedding(outputs, metadata=class_labels, label_img=images.unsqueeze(1))
                # loss = criterion(outputs, labels)
                loss.backward()

                if global_itr % args.update_freq == args.update_freq - 1:
                    if args.update_freq > 1:
                        with torch.no_grad():
                            for p in flow.parameters():
                                if p.grad is not None:
                                    p.grad /= args.update_freq

                    grad_norm = torch.nn.utils.clip_grad.clip_grad_norm_(
                        flow.parameters(), 1.)

                    optimizer.step()
                    optimizer_2.step()
                    optimizer_3.step()

                    optimizer.zero_grad()
                    optimizer_2.zero_grad()
                    optimizer_3.zero_grad()

                    update_lipschitz(flow)
                    ema.apply()
                    gnorm_meter.update(grad_norm)

                running_bpd += bpd.item()
                running_cls += cls_loss.item()
                running_loss += loss.item()

                if i % 100 == 99:
                    if is_write:
                        writer.add_scalar('bits per dimension',
                                          running_bpd / 100, global_itr)
                        writer.add_scalar('classification loss',
                                          running_cls / 100, global_itr)
                        writer.add_scalar('total loss', running_loss / 100,
                                          global_itr)
                    current_time = datetime.datetime.now().strftime(
                        '%Y-%m-%d_%I-%M-%S-%p')
                    print(current_time)
                    print(
                        '[%d, %5d] bpd: %.3f, cls_loss: %.3f, total_loss: %.3f'
                        % (epoch + 1, i + 1, running_bpd / 100,
                           running_cls / 100, running_loss / 100))
                    if epoch > 1 and running_loss / 100 < best_loss:
                        best_loss = running_loss / 100
                        print("best loss updated! :", best_loss)
                        torch.save(
                            {
                                'state_dict': flow.state_dict(),
                                'optimizer_state_dict': optimizer.state_dict(),
                                'args': args,
                                'ema': ema,
                            }, "{}_flow_best.pth".format(PATH))

                        torch.save(encoder.state_dict(),
                                   "{}_encoder_best.pth".format(PATH))
                        torch.save(classifier.state_dict(),
                                   "{}_classifier_best.pth".format(PATH))

                    # writer.add_figure('predictions vs. actuals',
                    #                   plot_classes_preds(net, images, labels))
                    running_loss = 0.0
                    running_bpd = 0.0
                    running_cls = 0.0

                del images
                torch.cuda.empty_cache()
                gc.collect()

            if epoch % 50 == 49:
                torch.save(
                    {
                        'state_dict': flow.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'args': args,
                        'ema': ema,
                    }, "{}_flow_{}.pth".format(PATH, epoch + 1))

                torch.save(encoder.state_dict(),
                           "{}_encoder_{}.pth".format(PATH, epoch + 1))
                torch.save(classifier.state_dict(),
                           "{}_classifier_{}.pth".format(PATH, epoch + 1))

    PATH_1 = "/home/taehokim/PycharmProjects/RNCl/runs/hybrid-tiny_imagenet-known-20-split0-2020-09-21_05-49-50-PM"
    PATH = "{}/{}{}_hybrid".format(PATH_1, DATASET, fold_num)

    if args.multi_eval:
        for i in range(50, 550, 50):
            test_encoder = encoder32()
            test_encoder.to(device)
            test_encoder.load_state_dict(
                torch.load("{}_encoder_{}.pth".format(PATH, i)))
            # state_dict = torch.load("{}_encoder_{}.pth".format(PATH, i))
            # # create new OrderedDict that does not contain `module.`
            #
            # new_state_dict = OrderedDict()
            # for k, v in state_dict.items():
            #     name = k[7:]  # remove `module.`
            #     new_state_dict[name] = v
            # # load params
            # test_encoder.load_state_dict(new_state_dict)

            test_classifier = classifier32()
            test_classifier.to(device)
            # state_dict = torch.load("{}_classifier_{}.pth".format(PATH, i))
            # # create new OrderedDict that does not contain `module.`
            #
            # new_state_dict = OrderedDict()
            # for k, v in state_dict.items():
            #     name = k[7:]  # remove `module.`
            #     new_state_dict[name] = v
            # # load params
            # test_classifier.load_state_dict(new_state_dict)
            test_classifier.load_state_dict(
                torch.load("{}_classifier_{}.pth".format(PATH, i)))

            test_flow = ResidualFlow(n_classes=20,
                                     input_size=(64, 128, 4, 4),
                                     n_blocks=[32, 32, 32],
                                     intermediate_dim=512,
                                     factor_out=False,
                                     quadratic=False,
                                     init_layer=None,
                                     actnorm=True,
                                     fc_actnorm=False,
                                     dropout=0,
                                     fc=False,
                                     coeff=0.98,
                                     vnorms='2222',
                                     n_lipschitz_iters=None,
                                     sn_atol=1e-3,
                                     sn_rtol=1e-3,
                                     n_power_series=None,
                                     n_dist='poisson',
                                     n_samples=1,
                                     kernels='3-1-3',
                                     activation_fn='swish',
                                     fc_end=True,
                                     n_exact_terms=2,
                                     preact=True,
                                     neumann_grad=True,
                                     grad_in_forward=False,
                                     first_resblock=True,
                                     learn_p=False,
                                     classification='hybrid',
                                     classification_hdim=256,
                                     block_type='resblock')

            test_flow.to(device)

            with torch.no_grad():
                x = torch.rand(1, *input_size[1:]).to(device)
                test_flow(x)
            checkpt = torch.load("{}_flow_{}.pth".format(PATH, i))
            sd = {
                k: v
                for k, v in checkpt['state_dict'].items()
                if 'last_n_samples' not in k
            }
            state = test_flow.state_dict()
            state.update(sd)
            test_flow.load_state_dict(state, strict=True)
            # test_ema.set(checkpt['ema'])

            hybrid = HybridModel(test_encoder, test_classifier, test_flow)

            closed_acc = evalute_classifier(hybrid, closed_testloader)
            print("closed-set accuracy: ", closed_acc)
            auc_d = evaluate_openset(hybrid, closed_testloader,
                                     open_testloader)
            print("auc discriminator: ", auc_d)

            result_file = '{}/{}{}.txt'.format(runs, DATASET, fold_num)

            current_time = datetime.datetime.now().strftime(
                '%Y-%m-%d_%I-%M-%S-%p')

            if is_write:
                # 'a' creates the file if it does not already exist
                with open(result_file, 'a') as f:
                    f.write(current_time + "\n")
                    f.write("seed: {}\n".format(args.seed))
                    f.write("{}{} \n".format(DATASET, fold_num))
                    f.write("{} epoch\n".format(i))
                    f.write("closed-set accuracy: {} \n".format(closed_acc))
                    f.write("AUROC: {} \n".format(auc_d))
    else:
        PATH_1 = "/home/taehokim/PycharmProjects/RNCl/runs/hybrid-tiny_imagenet-known-20-split0-2020-09-21_05-49-50-PM"
        PATH = "{}/{}{}_hybrid".format(PATH_1, DATASET, fold_num)

        test_encoder = encoder32()
        test_encoder.to(device)
        test_encoder.load_state_dict(
            torch.load("{}_encoder_latest.pth".format(PATH)))

        test_classifier = classifier32()
        test_classifier.to(device)
        test_classifier.load_state_dict(
            torch.load("{}_classifier_latest.pth".format(PATH)))

        test_flow = ResidualFlow(n_classes=20,
                                 input_size=(64, 128, 4, 4),
                                 n_blocks=[32, 32, 32],
                                 intermediate_dim=512,
                                 factor_out=False,
                                 quadratic=False,
                                 init_layer=None,
                                 actnorm=True,
                                 fc_actnorm=False,
                                 dropout=0,
                                 fc=False,
                                 coeff=0.98,
                                 vnorms='2222',
                                 n_lipschitz_iters=None,
                                 sn_atol=1e-3,
                                 sn_rtol=1e-3,
                                 n_power_series=None,
                                 n_dist='poisson',
                                 n_samples=1,
                                 kernels='3-1-3',
                                 activation_fn='swish',
                                 fc_end=True,
                                 n_exact_terms=2,
                                 preact=True,
                                 neumann_grad=True,
                                 grad_in_forward=False,
                                 first_resblock=True,
                                 learn_p=False,
                                 classification='hybrid',
                                 classification_hdim=256,
                                 block_type='resblock')

        test_flow.to(device)

        with torch.no_grad():
            x = torch.rand(1, *input_size[1:]).to(device)
            test_flow(x)
        checkpt = torch.load("{}_flow_latest.pth".format(PATH))
        sd = {
            k: v
            for k, v in checkpt['state_dict'].items()
            if 'last_n_samples' not in k
        }
        state = test_flow.state_dict()
        state.update(sd)
        test_flow.load_state_dict(state, strict=True)

        hybrid = HybridModel(test_encoder, test_classifier, test_flow)

        closed_acc = evalute_classifier(hybrid, closed_testloader)
        print("closed-set accuracy: ", closed_acc)
        auc_d = evaluate_openset(hybrid, closed_testloader, open_testloader)
        print("auc discriminator: ", auc_d)
Example 7
shutil.copy('params.py', os.path.join(result, 'params.py'))
shutil.copy('generate.py', os.path.join(result, 'generate.py'))
shutil.copy('net.py', os.path.join(result, 'net.py'))
shutil.copytree('WaveNet', os.path.join(result, 'WaveNet'))

# Model
encoder = UpsampleNet(params.channels, params.upsample_factors)
wavenet = WaveNet(params.n_loop, params.n_layer, params.filter_size,
                  params.input_dim, params.residual_channels,
                  params.dilated_channels, params.skip_channels,
                  params.quantize, params.use_logistic, params.n_mixture,
                  params.log_scale_min, params.condition_dim,
                  params.dropout_zero_rate)

if params.ema_mu < 1:
    decoder = ExponentialMovingAverage(wavenet, params.ema_mu)
else:
    decoder = wavenet

if params.use_logistic:
    loss_fun = wavenet.calculate_logistic_loss
    acc_fun = None
else:
    loss_fun = chainer.functions.softmax_cross_entropy
    acc_fun = chainer.functions.accuracy
model = EncoderDecoderModel(encoder, decoder, loss_fun, acc_fun)

# Optimizer
optimizer = chainer.optimizers.Adam(params.lr / len(args.gpus))
optimizer.setup(model)