Example #1
def train_gmn():

    # ==> initial check
    assert args.dataset == 'imagenet'

    # ==> gpu configuration
    ut.initialize_GPU(args)

    # ==> set up model path and log path.
    model_path, log_path = ut.set_path(args)

    # ==> import library
    import keras
    import data_loader
    import model_factory
    import data_generator

    # ==> get dataset information
    trn_config = data_loader.get_config(args)
    print('trn_config:', trn_config)
    params = {
        'cg': trn_config,
        'processes': 12,
        'batch_size': args.batch_size,
    }
    trn_gen, val_gen = data_generator.setup_generator(**params)

    # ==> load model
    gmn = model_factory.two_stream_matching_networks(trn_config)
    gmn.summary()

    # ==> attempt to load pre-trained model
    if args.resume:
        if os.path.isfile(args.resume):
            gmn.load_weights(args.resume, by_name=True)
            print('==> successfully loaded the model: {}'.format(args.resume))
        else:
            print("==> no checkpoint found at '{}'".format(args.resume))

    # ==> set up callbacks, e.g. lr schedule, tensorboard, save checkpoint.
    normal_lr = keras.callbacks.LearningRateScheduler(ut.step_decay(args))
    tbcallbacks = keras.callbacks.TensorBoard(log_dir=log_path,
                                              histogram_freq=0,
                                              write_graph=False,
                                              write_images=False)
    callbacks = [
        keras.callbacks.ModelCheckpoint(os.path.join(model_path, 'model.h5'),
                                        monitor='val_loss',
                                        save_best_only=True,
                                        mode='min'), normal_lr, tbcallbacks
    ]

    gmn.fit_generator(trn_gen,
                      steps_per_epoch=600,
                      epochs=args.epochs,
                      validation_data=val_gen,
                      validation_steps=100,
                      callbacks=callbacks,
                      verbose=1)
def adapt_gmn():

    # ==> gpu configuration
    ut.initialize_GPU(args)

    # ==> set up model path and log path.
    model_path, log_path = ut.set_path(args)

    # ==> import library
    import keras
    import data_loader
    import model_factory
    import data_generator

    # ==> get dataset information
    trn_config = data_loader.get_config(args)

    params = {'cg': trn_config, 'processes': 12, 'batch_size': args.batch_size}

    trn_gen, val_gen = data_generator.setup_generator(**params)

    # ==> load networks
    gmn = model_factory.two_stream_matching_networks(trn_config,
                                                     sync=False,
                                                     adapt=False)
    model = model_factory.two_stream_matching_networks(trn_config,
                                                       sync=False,
                                                       adapt=True)
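    # gmn is the plain pre-trained matching network (adapt=False), while model adds the
    # adaptation module (adapt=True); gmn's weights are copied into model further below.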

    # ==> attempt to load pre-trained model
    if args.resume:
        if os.path.isfile(args.resume):
            model.load_weights(args.resume, by_name=True)
            print('==> successfully loaded the model: {}'.format(args.resume))
        else:
            print("==> no checkpoint found at '{}'".format(args.resume))

    # ==> attempt to load pre-trained GMN
    elif args.gmn_path:
        if os.path.isfile(args.gmn_path):
            gmn.load_weights(args.gmn_path, by_name=True)
            print('==> successfully loaded the model: {}'.format(args.gmn_path))
        else:
            print("==> no checkpoint found at '{}'".format(args.gmn_path))

    # ==> print model summary
    model.summary()

    # ==> transfer weights from gmn to new model (this step is slow, but can't seem to avoid it)
    for i, layer in enumerate(gmn.layers):
        if isinstance(layer, model.__class__):
            # this layer is itself a nested keras Model: copy its sub-layers by name
            for l in layer.layers:
                weights = l.get_weights()
                if len(weights) > 0:
                    model.layers[i].get_layer(l.name).set_weights(weights)
        else:
            # plain layer: copy the weights directly by layer name
            weights = layer.get_weights()
            if len(weights) > 0:
                model.get_layer(layer.name).set_weights(weights)

    # ==> set up callbacks, e.g. lr schedule, tensorboard, save checkpoint.
    normal_lr = keras.callbacks.LearningRateScheduler(ut.step_decay(args))
    tbcallbacks = keras.callbacks.TensorBoard(log_dir=log_path,
                                              histogram_freq=0,
                                              write_graph=False,
                                              write_images=False)
    callbacks = [
        keras.callbacks.ModelCheckpoint(os.path.join(model_path, 'model.h5'),
                                        monitor='val_loss',
                                        save_best_only=True,
                                        mode='min'), normal_lr, tbcallbacks
    ]

    model.fit_generator(trn_gen,
                        steps_per_epoch=600,
                        epochs=args.epochs,
                        validation_data=val_gen,
                        validation_steps=100,
                        callbacks=callbacks,
                        verbose=1)
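
# Note: both train_gmn and adapt_gmn drive keras.callbacks.LearningRateScheduler with
# ut.step_decay(args), a repo helper that is not shown above. The sketch below only
# illustrates such a factory: it returns the epoch -> lr function the callback expects.
# The attribute names args.lr and args.decay_epochs are assumptions, not the repo's
# actual fields.
def step_decay(args):
    def schedule(epoch):
        # drop the base learning rate by 10x every args.decay_epochs epochs
        return args.lr * (0.1 ** (epoch // args.decay_epochs))
    return schedule
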
def train(model_path):
    config_values = "model" + hp.model.type + "_proj" + str(hp.model.proj) + "_vlad" + str(hp.model.vlad_centers) \
                    + "_ghost" + str(hp.model.ghost_centers) + "_spk" + str(hp.train.N) + "_utt" + str(hp.train.M) \
                    + "_dropout" + str(hp.model.dropout) + "_feat" + hp.data.feat_type + "_lr" + str(hp.train.lr) \
                    + "_optim" + hp.train.optim + "_loss" + hp.train.loss \
                    + "_wd" + str(hp.train.wd) + "_fr" + str(hp.data.tisv_frame)
    #checkpoint and log dir
    os.makedirs(hp.train.checkpoint_dir, exist_ok=True)
    log_file = config_values + ".log"
    log_file_path = os.path.join(hp.train.checkpoint_dir, log_file)

    #load model
    embedder_net = Resnet34_VLAD()
    embedder_net = torch.nn.DataParallel(embedder_net)
    embedder_net = embedder_net.cuda()
    print(embedder_net)

    #load dataset
    train_dataset = VoxCeleb_utter()
    train_loader = DataLoader(train_dataset,
                              batch_size=hp.train.N,
                              shuffle=True,
                              num_workers=hp.train.num_workers,
                              drop_last=True)
    loss_fn = SILoss(hp.model.proj, train_dataset.num_of_spk).cuda()

    if hp.train.restore:
        embedder_net.load_state_dict(
            torch.load(os.path.join(hp.train.checkpoint_dir, model_path)))
        loss_fn.load_state_dict(
            torch.load(
                os.path.join(hp.train.checkpoint_dir, "loss_" + model_path)))
    #Both net and loss have trainable parameters

    if hp.train.optim.lower() == 'sgd':
        optimizer = torch.optim.SGD([{
            'params': embedder_net.parameters()
        }, {
            'params': loss_fn.parameters()
        }],
                                    lr=hp.train.lr,
                                    weight_decay=hp.train.wd)
    elif hp.train.optim.lower() == 'adam':
        optimizer = torch.optim.Adam([{
            'params': embedder_net.parameters()
        }, {
            'params': loss_fn.parameters()
        }],
                                     lr=hp.train.lr,
                                     weight_decay=hp.train.wd)
    elif hp.train.optim.lower() == 'adadelta':
        optimizer = torch.optim.Adadelta([{
            'params': embedder_net.parameters()
        }, {
            'params': loss_fn.parameters()
        }],
                                         lr=hp.train.lr,
                                         weight_decay=hp.train.wd)

    print(optimizer)
    iteration = 0
    for e in range(hp.train.epochs):
        step_decay(e, optimizer)  #stage based lr scheduler
        total_loss = 0

        for batch_id, (mel_db_batch, spk_id) in enumerate(train_loader):
            embedder_net.train()
            mel_db_batch = mel_db_batch.cuda()
            spk_id = spk_id.cuda()
            mel_db_batch = torch.reshape(
                mel_db_batch, (hp.train.N * hp.train.M, mel_db_batch.size(2),
                               mel_db_batch.size(3)))
            optimizer.zero_grad()
            embeddings = embedder_net(mel_db_batch)
            #get loss, call backward, step optimizer
            loss, _ = loss_fn(embeddings,
                              spk_id)  #wants (Speaker, Utterances, embedding)
            loss.backward()
            optimizer.step()
            total_loss = total_loss + loss.item()  # .item() so the graph is not kept alive across batches
            iteration += 1

            if (batch_id + 1) % hp.train.log_interval == 0 or \
               (batch_id + 1) % (len(train_dataset)//hp.train.N) == 0:
                mesg = "{0}\tEpoch:{1}[{2}/{3}], Iteration:{4}\tLoss:{5:.4f}\tTLoss:{6:.4f}\t\n".format(
                    time.ctime(), e + 1, batch_id + 1,
                    len(train_dataset) // hp.train.N, iteration, loss,
                    total_loss / (batch_id + 1))
                print(mesg)
                with open(log_file_path, 'a') as f:
                    f.write(mesg)

                if (batch_id + 1) % (len(train_dataset) // hp.train.N) == 0:
                    # scheduler.step(total_loss)  # uncomment for a ReduceLROnPlateau scheduler
                    print("learning rate: {0:.6f}\n".format(
                        optimizer.param_groups[1]['lr']))

        # save a checkpoint and evaluate EER on held-out data every checkpoint_interval epochs
        if hp.train.checkpoint_dir is not None and (
                e + 1) % hp.train.checkpoint_interval == 0:
            # switch model to evaluation mode
            embedder_net.eval()

            ckpt_model_filename = config_values + '.pth'
            ckpt_model_path = os.path.join(hp.train.checkpoint_dir,
                                           ckpt_model_filename)
            ckpt_loss_path = os.path.join(hp.train.checkpoint_dir,
                                          'loss_' + ckpt_model_filename)
            torch.save(loss_fn.state_dict(), ckpt_loss_path)
            torch.save(embedder_net.state_dict(), ckpt_model_path)

            eer, thresh = testVoxCeleb(ckpt_model_path)
            mesg = ("\nEER : %0.4f (thres:%0.2f)\n" % (eer, thresh))
            mesg += ("learning rate: {0:.8f}\n".format(
                optimizer.param_groups[1]['lr']))
            print(mesg)
            with open(log_file_path, 'a') as f:
                f.write(mesg)
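
# Note: step_decay(e, optimizer), called at the top of each epoch, is the repo's
# stage-based lr scheduler and is defined elsewhere. The sketch below only illustrates
# the idea of a stage-wise schedule; the epoch boundaries and decay factors are
# assumptions, not the actual values used.
def step_decay(epoch, optimizer):
    if epoch < 30:
        lr = hp.train.lr
    elif epoch < 60:
        lr = hp.train.lr * 0.1
    else:
        lr = hp.train.lr * 0.01
    # write the new lr into both parameter groups (embedder and loss head)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
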
Example #4
def train(model: keras.models.Model,
          optimizer: dict,
          save_path: str,
          train_dir: str,
          valid_dir: str,
          batch_size: int = 32,
          epochs: int = 10,
          samples_per_epoch=1000,
          pretrained=None,
          augment: bool = True,
          weight_mode=None,
          verbose=0,
          **kwargs):
    """ Trains the model with the given configurations. """
    shape = model.input_shape[1:3]
    optimizer_cpy = optimizer.copy()
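    # keep an untouched copy: the dict is consumed with pop() below, but the original
    # settings are still needed for the sgd check and the xml description further down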
    shared_gen_args = {
        'rescale': 1. / 255,  # scale pixel values into [0, 1]
    }
    train_gen_args = {}
    if augment:
        train_gen_args = {
            "fill_mode": 'reflect',
            'horizontal_flip': True,
            'vertical_flip': True,
            'width_shift_range': .15,
            'height_shift_range': .15,
            'shear_range': .5,
            'rotation_range': 45,
            'zoom_range': .2,
        }
    gen = IDG(**{**shared_gen_args, **train_gen_args})
    gen = gen.flow_from_directory(train_dir,
                                  target_size=shape,
                                  batch_size=batch_size,
                                  seed=SEED)

    val_count = len(
        glob(os.path.join(valid_dir, '**', '*.jpg'), recursive=True))
    valid_gen = IDG(**shared_gen_args)

    optim = getattr(keras.optimizers, optimizer['name'])
    if optimizer.pop('name') != 'sgd':
        optimizer.pop('nesterov')
    schedule = optimizer.pop('schedule')
    if schedule == 'decay' and 'lr' in optimizer.keys():
        initial_lr = optimizer.pop('lr')
    else:
        initial_lr = 0.01
    optim = optim(**optimizer)

    callbacks = [
        utils.checkpoint(save_path),
        utils.csv_logger(save_path),
    ]

    if pretrained is not None:
        if not os.path.exists(pretrained):
            raise FileNotFoundError(pretrained)

        model.load_weights(pretrained, by_name=False)
        if verbose == 1:
            print("Loaded weights from {}".format(pretrained))

    if optimizer_cpy['name'] == 'sgd':
        if schedule == 'decay':
            callbacks.append(utils.step_decay(epochs, initial_lr=initial_lr))
        elif schedule == 'big_drop':
            callbacks.append(utils.constant_schedule())

    model.compile(optim,
                  loss='categorical_crossentropy',
                  metrics=['accuracy', top3_acc])

    create_xml_description(save=os.path.join(save_path, 'model_config.xml'),
                           title=model.name,
                           epochs=epochs,
                           batch_size=batch_size,
                           samples_per_epoch=samples_per_epoch,
                           augmentations=augment,
                           schedule=schedule,
                           optimizer=optimizer_cpy,
                           **kwargs)

    if weight_mode:
        class_weights = [[key, value] for key, value in weight_mode.items()]
        filen = os.path.join(save_path, 'class_weights.npy')
        np.save(filen, class_weights)

    h = None  # has to be initialized here, so we can reference it later
    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            h = model.fit_generator(
                gen,
                steps_per_epoch=samples_per_epoch / batch_size,
                epochs=epochs,
                validation_data=valid_gen.flow_from_directory(
                    valid_dir,
                    target_size=shape,
                    batch_size=batch_size,
                    seed=SEED),
                validation_steps=val_count / batch_size,
                callbacks=callbacks,
                class_weight=weight_mode,
                verbose=2)
    except KeyboardInterrupt:
        save_results(verbose=1, save_path=save_path, model=model, hist=h)
        return

    save_results(verbose=1, save_path=save_path, model=model, hist=h)
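
# Note: the optimizer argument is consumed as a plain dict: 'name' picks a class from
# keras.optimizers, 'schedule' selects the lr schedule ('decay' or 'big_drop'), 'lr'
# seeds the decay schedule when schedule == 'decay', 'nesterov' is only kept for sgd,
# and the remaining keys are passed straight to the optimizer constructor. A call could
# look like the following (paths and hyper-parameter values are purely illustrative):
#
# train(model,
#       optimizer={'name': 'sgd', 'schedule': 'decay', 'lr': 0.01,
#                  'momentum': 0.9, 'nesterov': True},
#       save_path='runs/exp1',
#       train_dir='data/train',
#       valid_dir='data/valid',
#       batch_size=32,
#       epochs=10,
#       verbose=1)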