def _build_inputs(signals_v1, signals_v2, signals_v3, rf_locs, context):
    """Project the V1/V2/V3 RF signals to ROI inputs and stack them on axis 1.

    ``rf_locs`` is the ``(V1, V2, V3)`` tuple of RF-location arrays obtained
    from ``modules.get_RFs``; each signal is paired with its own area's
    locations before the three results are concatenated along axis 1.
    """
    per_area = [
        modules.get_inputsROI(signals, locs, context)
        for signals, locs in zip((signals_v1, signals_v2, signals_v3), rf_locs)
    ]
    return concat(*per_area, dim=1)


def _mean(values):
    """Return the arithmetic mean of ``values``, or NaN if it is empty."""
    return sum(values) / len(values) if values else float('nan')


def run_experiment(fraction_train,
                   load_model=False,
                   old_runname=None,
                   start_epoch=None):
    """Train the V1/V2/V3 -> image GAN and evaluate it on the test split.

    Parameters
    ----------
    fraction_train : number
        Fraction of the training split to use; forwarded to
        ``modules.make_iterator_preprocessed`` and embedded in the run name.
    load_model : bool
        If True, resume from the checkpoints of a previous run.
    old_runname : str, optional
        Run to resume. Required when ``load_model`` is True. Logs,
        checkpoints and loss arrays are then all written under this name.
    start_epoch : int, optional
        Epoch of the checkpoint to resume from. Required when ``load_model``
        is True. The first resumed epoch is saved/logged as
        ``start_epoch + 1``.

    Raises
    ------
    ValueError
        If ``load_model`` is True but ``old_runname`` or ``start_epoch`` is
        missing.

    Side effects: writes TensorBoard logs to ``logs/<run>``. Writes each
    epoch's generator/discriminator parameters to ``saved_models/<run>``.
    Writes the latest epoch's per-batch loss arrays to ``saved_models/<run>``
    (overwritten every epoch).
    """
    # Fail fast on an incomplete resume request instead of crashing later on
    # 'logs/' + None or a checkpoint path like 'netG_None.model'.
    if load_model and (old_runname is None or start_epoch is None):
        raise ValueError(
            'load_model=True requires both old_runname and start_epoch')

    runname = f'splitted_data_{str(fraction_train)}'
    # All outputs of a resumed run go to the original run's directories.
    # Previously the loss arrays were written to saved_models/<runname>,
    # which was never created when resuming.
    active_run = old_runname if load_model else runname
    save_dir = f'saved_models/{active_run}'
    # Global epoch index of loop iteration 0: checkpoints and logs of a
    # resumed run continue right after the loaded checkpoint.
    epoch_offset = start_epoch + 1 if load_model else 0

    device = Device.GPU1
    epochs = 50
    batch_size = 4
    in_chan = 15

    context = cpu() if device.value == -1 else gpu(device.value)
    # ----------------------------------------------------
    summaryWriter = SummaryWriter('logs/' + active_run, flush_secs=5)
    try:
        train_iter = modules.make_iterator_preprocessed(
            'training',
            'V1',
            'V2',
            'V3',
            batch_size=batch_size,
            shuffle=True,
            fraction_train=fraction_train)
        test_iter = modules.make_iterator_preprocessed('testing',
                                                       'V1',
                                                       'V2',
                                                       'V3',
                                                       batch_size=batch_size,
                                                       shuffle=True)

        rf_locs = (modules.get_RFs('V1', context),
                   modules.get_RFs('V2', context),
                   modules.get_RFs('V3', context))

        os.makedirs(save_dir, exist_ok=True)

        with Context(context):
            discriminator = Discriminator(in_chan)
            generator = Generator(in_chan, context)

            if load_model:
                generator.network.load_parameters(
                    f'{save_dir}/netG_{start_epoch}.model', ctx=context)
                # Load onto the same context as the generator.
                discriminator.network.load_parameters(
                    f'{save_dir}/netD_{start_epoch}.model', ctx=context)

            gen_lossfun = gen.Lossfun(1, 100, 1, context)
            d = discriminator.network

            dis_lossfun = dis.Lossfun(1)
            g = generator.network

            print('train_dataset_length:', len(train_iter._dataset))

            for epoch in range(epochs):
                step = epoch_offset + epoch

                loss_discriminator_train = []
                loss_generator_train = []

                # ====================
                # T R A I N I N G
                # ====================
                for RFsignalsV1, RFsignalsV2, RFsignalsV3, targets in tqdm(
                        train_iter, total=len(train_iter)):
                    inputs = _build_inputs(RFsignalsV1, RFsignalsV2,
                                           RFsignalsV3, rf_locs, context)
                    # Swap the last two axes of the targets.
                    targets = targets.as_in_context(context).transpose(
                        (0, 1, 3, 2))

                    # Discriminator step, then generator step, same batch.
                    loss_discriminator_train.append(
                        discriminator.train(g, inputs, targets))
                    loss_generator_train.append(
                        generator.train(d, inputs, targets))

                generator.network.save_parameters(
                    f'{save_dir}/netG_{step}.model')
                discriminator.network.save_parameters(
                    f'{save_dir}/netD_{step}.model')

                # ====================
                # T E S T I N G
                # ====================
                loss_discriminator_test = []
                loss_generator_test = []
                last_batch = None

                for RFsignalsV1, RFsignalsV2, RFsignalsV3, targets in test_iter:
                    inputs = _build_inputs(RFsignalsV1, RFsignalsV2,
                                           RFsignalsV3, rf_locs, context)
                    targets = targets.as_in_context(context).transpose(
                        (0, 1, 3, 2))

                    # One generator pass and one discriminator pass on the
                    # fake pair serve both test losses.
                    pred = g(inputs)
                    d_fake = d(concat(inputs, pred, dim=1))

                    dis_loss_test = 0.5 * (dis_lossfun(0, d_fake) + dis_lossfun(
                        1, d(concat(inputs, targets, dim=1))))
                    loss_discriminator_test.append(
                        float(dis_loss_test.asscalar()))

                    gen_loss_test = gen_lossfun(1, d_fake, targets, pred)
                    loss_generator_test.append(float(gen_loss_test.asscalar()))

                    last_batch = (inputs, targets, pred)

                # Images are taken from the last test batch (if any).
                if last_batch is not None:
                    last_inputs, last_targets, last_pred = last_batch
                    summaryWriter.add_image(
                        "input",
                        modules.leclip(last_inputs.expand_dims(2).sum(1)),
                        step)
                    summaryWriter.add_image("target",
                                            modules.leclip(last_targets), step)
                    summaryWriter.add_image("pred", modules.leclip(last_pred),
                                            step)

                summaryWriter.add_scalar("dis/loss_discriminator_train",
                                         _mean(loss_discriminator_train), step)
                summaryWriter.add_scalar("gen/loss_generator_train",
                                         _mean(loss_generator_train), step)
                summaryWriter.add_scalar("dis/loss_discriminator_test",
                                         _mean(loss_discriminator_test), step)
                summaryWriter.add_scalar("gen/loss_generator_test",
                                         _mean(loss_generator_test), step)

                # Per-batch losses of the current epoch only; each epoch
                # overwrites the previous epoch's arrays.
                np.save(f'{save_dir}/Gloss_train',
                        np.array(loss_generator_train))
                np.save(f'{save_dir}/Dloss_train',
                        np.array(loss_discriminator_train))
                np.save(f'{save_dir}/Gloss_test',
                        np.array(loss_generator_test))
                np.save(f'{save_dir}/Dloss_test',
                        np.array(loss_discriminator_test))
    finally:
        summaryWriter.close()
# Example #2
# 0
    # ----------------------------------------------------
    # SummaryWriter is for visualizing logs in tensorboard
    # ----------------------------------------------------
    # NOTE(review): fragment of a larger definition; `runname` comes from
    # code not visible here. The log directory '../../logs/<runname>' is
    # relative to the current working directory.
    summaryWriter = SummaryWriter('../../logs/' + runname, flush_secs=5)

# NOTE(review): `context`, `batch_size` and `in_chan` are read here but are
# defined outside the visible snippet.
with Context(context):
    # ----------------------------------------------------
    # RF centers: overlapping
    # ----------------------------------------------------
    RFlocs_V1_overlapped_avg = modules.get_RFs('V1', context)
    RFlocs_V2_overlapped_avg = modules.get_RFs('V2', context)
    RFlocs_V3_overlapped_avg = modules.get_RFs('V3', context)

    test_iter = modules.make_iterator_preprocessed('testing',
                                                   'V1',
                                                   'V2',
                                                   'V3',
                                                   batch_size=batch_size,
                                                   shuffle=True)

    # Inspect only the first test batch: record the size of axis 2 of every
    # RF signal tensor in it (all batch items except the trailing targets).
    # These sizes are passed to the generator as RF_in_units below.
    RF_signals_lengths = []
    for *RFsignals, targets in test_iter:
        for s in RFsignals:
            RF_signals_lengths.append(s.shape[2])
        break

    discriminator = Discriminator(in_chan)
    generator = Generator(in_chan,
                          context,
                          RF_in_units=RF_signals_lengths,
                          conv_input_shape=(96, 96),
                          train_RF=True)