Example #1
def pretrain_fine(epoch, fine_id):
    net.fines[fine_id].train()

    optimizer, lr = get_optim(net.fines[fine_id].parameters(), args, mode='preTrain', epoch=epoch)

    print('==> Epoch #%d, LR=%.4f' % (epoch, lr))
    required_train_loader = get_dataLoder(args, classes=net.class_set[fine_id], mode='preTrain')
    predictor = net.fines[fine_id]
    for batch_idx, (inputs, targets) in enumerate(required_train_loader):
        if cf.use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda() # GPU settings
        optimizer.zero_grad()
        inputs, targets = Variable(inputs), Variable(targets).long()

        outputs = predictor(net.share(inputs)) # Forward Propagation
        loss = pred_loss(outputs, targets)
        loss.backward()  # Backward Propagation
        optimizer.step() # Optimizer update

        num_ins = targets.size(0)
        _, outputs = torch.max(outputs, 1)
        correct = outputs.eq(targets.data).cpu().sum()
        acc = 100.*correct.item()/num_ins

        sys.stdout.write('\r')
        sys.stdout.write('Pre-train Epoch [%3d/%3d] Iter [%3d/%3d]\t\t Loss: %.4f Accuracy: %.3f%%'
                         %(epoch, args.num_epochs_pretrain, batch_idx+1, (required_train_loader.dataset.train_data.shape[0]//args.pretrain_batch_size)+1,
                           loss.item(), acc))
        sys.stdout.flush()
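Several of these examples call get_optim(params, args, mode=..., epoch=...) and unpack an (optimizer, lr) pair; the helper itself is not shown on this page. A minimal sketch of what such a helper could look like, assuming args carries a base learning rate and a step-decay schedule (args.lr, args.lr_finetune, args.lr_decay and args.lr_decay_every are hypothetical field names):

import torch.optim as optim

def get_optim(params, args, mode='preTrain', epoch=1):
    # Hypothetical sketch: pick a base LR per mode, step-decay it by epoch,
    # then return the optimizer together with the LR used (handy for logging).
    base_lr = args.lr if mode == 'preTrain' else args.lr_finetune
    lr = base_lr * (args.lr_decay ** (epoch // args.lr_decay_every))
    optimizer = optim.SGD(params, lr=lr, momentum=0.9, weight_decay=5e-4)
    return optimizer, lr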
Example #2
def augment(out_dir, chkpt_path, train_loader, valid_loader, model, writer,
            logger, device, config):

    w_optim = utils.get_optim(model.weights(), config.w_optim)

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim, config.epochs, eta_min=config.w_optim.lr_min)

    init_epoch = -1

    if chkpt_path is not None:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path)
        model.load_state_dict(checkpoint['model'])
        w_optim.load_state_dict(checkpoint['w_optim'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        init_epoch = checkpoint['epoch']
    else:
        logger.info("Starting new training run")

    logger.info("Model params count: {:.3f} M, size: {:.3f} MB".format(
        utils.param_size(model), utils.param_count(model)))

    # training loop
    logger.info('begin training')
    best_top1 = 0.
    tot_epochs = config.epochs
    for epoch in itertools.count(init_epoch + 1):
        if epoch == tot_epochs: break

        drop_prob = config.drop_path_prob * epoch / tot_epochs
        model.drop_path_prob(drop_prob)

        lr = lr_scheduler.get_lr()[0]

        # training
        train(train_loader, None, model, writer, logger, None, w_optim, None,
              lr, epoch, tot_epochs, device, config)
        lr_scheduler.step()

        # validation
        cur_step = (epoch + 1) * len(train_loader)
        top1 = validate(valid_loader, model, writer, logger, epoch, tot_epochs,
                        cur_step, device, config)

        # save
        if best_top1 < top1:
            best_top1 = top1
            is_best = True
        else:
            is_best = False

        if config.save_freq != 0 and epoch % config.save_freq == 0:
            save_checkpoint(out_dir, model, w_optim, None, lr_scheduler, epoch,
                            logger)

        print("")

    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    tprof.stat_acc('model_' + NASModule.get_device()[0])
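Examples #2 and #16 follow a different convention: utils.get_optim(parameters, optim_config) takes an optimizer config namespace (config.w_optim / config.a_optim) and returns a bare optimizer. A minimal config-driven sketch, assuming field names type, lr, momentum and weight_decay (none of these are confirmed by the source):

import torch.optim as optim

def get_optim(params, optim_config):
    # Hypothetical sketch: dispatch on an assumed optim_config.type field.
    if optim_config.type == 'sgd':
        return optim.SGD(params, lr=optim_config.lr,
                         momentum=optim_config.momentum,
                         weight_decay=optim_config.weight_decay)
    if optim_config.type == 'adam':
        return optim.Adam(params, lr=optim_config.lr,
                          weight_decay=optim_config.weight_decay)
    raise ValueError('unknown optimizer type: %s' % optim_config.type)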
Example #3
def pretrain_coarse(epoch):
    net.share.train()
    net.coarse.train()

    param = list(net.share.parameters())+list(net.coarse.parameters())
    optimizer, lr = get_optim(param, args, mode='preTrain', epoch=epoch)

    print('\n==> Epoch #%d, LR=%.4f' % (epoch, lr))
    for batch_idx, (inputs, targets) in enumerate(pretrainloader):
        if cf.use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda() # GPU setting
        optimizer.zero_grad()
        inputs, targets = Variable(inputs), Variable(targets)

        outputs = net.coarse(net.share(inputs)) # Forward Propagation

        loss = pred_loss(outputs, targets)
        loss.backward()  # Backward Propagation
        optimizer.step() # Optimizer update

        _, predicted = torch.max(outputs.data, 1)
        num_ins = targets.size(0)
        correct = predicted.eq(targets.data).cpu().sum()

        sys.stdout.write('\r')
        sys.stdout.write('Pre-train Epoch [%3d/%3d] Iter [%3d/%3d]\t\t Loss: %.4f Accuracy: %.3f%%'
                         %(epoch, args.num_epochs_pretrain, batch_idx+1, (pretrainloader.dataset.train_data.shape[0]//args.pretrain_batch_size)+1,
                           loss.item(), 100.*correct.item()/num_ins))
        sys.stdout.flush()
Example #4
def train_branch(branch, clustering_result, classes):
    for epoch in range(args.num_epochs_train):
        required_train_loader = get_dataLoder(args, classes = classes, mode='Train', one_hot=True)
        param = list(branch.parameters())
        optimizer, lr = get_optim(param, args, mode='preTrain', epoch=epoch)
        for batch_idx, (inputs, targets) in enumerate(required_train_loader):
            if cf.use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda() # GPU settings
            optimizer.zero_grad()
            inputs, targets = Variable(inputs), Variable(targets).float()
            outputs = branch(inputs)
            matrix = np.vstack(((np.ones(np.shape(clustering_result))-clustering_result), clustering_result))
            matrix = torch.from_numpy(matrix.transpose().astype(np.float32))
            if cf.use_cuda:
                matrix = matrix.cuda()
            outputs = outputs.mm(matrix)
            targets = targets.mm(matrix)
            loss = pred_loss(outputs,targets)
            loss.backward()  # Backward Propagation
            optimizer.step() # Optimizer update
            sys.stdout.write('\r')
            sys.stdout.write('Train Branch with Epoch [%3d/%3d] Iter [%3d/%3d]\t\t Loss: %.4f'
                         %(epoch+1, args.num_epochs_train, batch_idx+1, (required_train_loader.dataset.train_data.shape[0]//args.pretrain_batch_size)+1,
                           loss.item()))
            sys.stdout.flush()
    return branch
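The matrix built above maps per-class outputs onto two cluster scores: with clustering_result a 0/1 assignment of each class to one of two clusters, stacking (1 - clustering_result) and clustering_result and transposing gives one column per cluster, so outputs.mm(matrix) sums the scores of the classes belonging to each cluster. A tiny standalone check with made-up numbers:

import numpy as np
import torch

clustering_result = np.array([0, 1, 1, 0])             # class -> cluster id (0 or 1)
matrix = np.vstack((np.ones(4) - clustering_result,    # column 0 selects cluster-0 classes
                    clustering_result)).T.astype(np.float32)
outputs = torch.tensor([[0.1, 0.4, 0.3, 0.2]])
print(outputs.mm(torch.from_numpy(matrix)))            # tensor([[0.3000, 0.7000]])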
Example #5
def deploy(config):

    # Setup deploy directory and get G architecture
    deploy = deploy_utils.setup_deploy(config)

    for key, value in deploy.items():
        print(key, value)
    input('Press return to deploy this configuration')

    # Select gpu for model evaluation
    device = torch.device(config['gpu'])

    # Get model
    model = utils.get_model(deploy, device)

    # Get optimizer
    optim = utils.get_optim(deploy, model.parameters())

    # Load model checkpoint
    checkpoint = torch.load(deploy['weights'])

    # Load state dicts
    model.load_state_dict(checkpoint['state_dict'])
    optim.load_state_dict(checkpoint['optimizer'])

    # Put model in evaluation mode
    model.eval()

    progress = 0
    # Sample model and save outputs
    print(deploy['num_samples'])
    for i in range(deploy['num_samples']):
        # Generate random vector
        z = torch.randn(config['batch_size'], deploy['z_dim'], device=device)

        # Get sample from model
        sample = model(z).view(-1, 1, deploy['imsize'], deploy['imsize'])
        sample = ((sample * 0.5) + 0.5) * 10

        # Save the sample
        deploy_utils.save_deploy_sample(sample, i, deploy['deploy_path'])

        # Update console
        print('Sample number {} | Total {}'.format(progress,
                                                   deploy['num_samples']))
        progress += 1
        del z  # just to make sure
    print('Deploy completed with {} samples'.format(deploy['num_samples']))

    if deploy['hist']:
        print('Generating deploy histogram')
        deploy_utils.adc_and_histogram(deploy)
    if deploy['hamming']:
        print('Generating Hamming distance histogram')
        hamming.HD(deploy)
Example #6
def enhance_expert(Expert, Superclass, c, mode='clone'):
    if mode == 'clone':
        print(
            '\nThe new expert model is activated and is waiting for another class to be added'
        )
    elif mode == 'merge':
        for epoch in range(args.num_epochs_train):
            required_train_loader = get_dataLoder(args,
                                                  classes=Superclass[c],
                                                  mode='Train',
                                                  encoded=False,
                                                  one_hot=True)
            if epoch == 0:
                num = len(Superclass[c])
                Expert[c] = prepared_model(num, c)
                if cf.use_cuda:
                    Expert[c].cuda()
                    cudnn.benchmark = True
            param = list(Expert[c].parameters())
            optimizer, lr = get_optim(param,
                                      args,
                                      mode='preTrain',
                                      epoch=epoch)
            for batch_idx, (inputs,
                            targets) in enumerate(required_train_loader):
                if batch_idx >= args.num_test:
                    break
                if cf.use_cuda:
                    inputs, targets = inputs.cuda(), targets.cuda(
                    )  # GPU setting
                optimizer.zero_grad()
                inputs, targets = Variable(inputs), Variable(targets).long()
                outputs = Expert[c](inputs)  # Forward Propagation
                loss = pred_loss(outputs, targets)
                loss.backward()  # Backward Propagation
                optimizer.step()  # Optimizer update
                _, predicted = torch.max(outputs.data, 1)
                num_ins = targets.size(0)
                correct = predicted.eq((torch.max(targets.data,
                                                  1)[1])).cpu().sum()
                acc = 100. * correct.item() / num_ins
                sys.stdout.write('\r')
                sys.stdout.write(
                    'Train expert model with Epoch [%3d/%3d] Iter [%3d]\t\t Loss: %.4f Accuracy: %.3f%%'
                    % (epoch + 1, args.num_epochs_train, batch_idx + 1,
                       loss.item(), acc))
                sys.stdout.flush()
        save_model(Expert[c], c)
    else:
        print('\nmode error')
    return Expert
Example #7
def train_autoencoder(Autoencoder, Superclass, Old_superclass):
    # ================== used to train the new encoder ==================
    print('\n=========== refresh the autoencoders ===========')
    for dict in Superclass:
        refresh = False
        if dict not in Old_superclass.keys():
            refresh = True
        elif Superclass[dict] != Old_superclass[dict]:
            refresh = True
        if refresh:
            print('\nrefreshing the autoencoder:' + dict)
            Autoencoder[dict] = autoencoder(args)
            if cf.use_cuda:
                Autoencoder[dict].cuda()
                cudnn.benchmark = True
            for epoch in range(args.num_epochs_train):
                Autoencoder[dict].train()
                required_train_loader = get_dataLoder(args,
                                                      classes=Superclass[dict],
                                                      mode='Train',
                                                      encoded=True,
                                                      one_hot=False)
                param = list(Autoencoder[dict].parameters())
                optimizer, lr = get_optim(param,
                                          args,
                                          mode='preTrain',
                                          epoch=epoch)
                for batch_idx, (inputs,
                                targets) in enumerate(required_train_loader):
                    if batch_idx >= args.num_test:
                        break
                    if cf.use_cuda:
                        inputs = inputs.cuda()  # GPU settings
                    optimizer.zero_grad()
                    inputs = Variable(inputs)
                    reconstructions, _ = Autoencoder[dict](inputs)
                    loss = cross_entropy(reconstructions, inputs)
                    loss.backward()  # Backward Propagation
                    optimizer.step()  # Optimizer update
                    sys.stdout.write('\r')
                    sys.stdout.write(
                        'Refreshing autoencoder:' + dict +
                        ' with Epoch [%3d/%3d] Iter [%3d]\t\t Loss: %.4f' %
                        (epoch + 1, args.num_epochs_train, batch_idx + 1,
                         loss.item()))
                    sys.stdout.flush()
            print('\nautoencoder model:' + str(dict) +
                  ' is constructed with final loss:' + str(loss.item()))
    return Autoencoder
Example #8
def gan(model, config):
    '''
        GAN setup function
    '''
    # Get G and D kwargs based on command line inputs
    g_kwargs, d_kwargs = get_gan_kwargs(config)

    # Set up models on GPU
    G = model.Generator(**g_kwargs).to(config['gpu'])
    D = model.Discriminator(**d_kwargs).to(config['gpu'])

    print(G)
    print(D)
    input('Press return to launch')

    # Initialize model weights
    G.weights_init()
    D.weights_init()

    # Set up model optimizer functions
    model_params = {'g_params': G.parameters(), 'd_params': D.parameters()}
    G_optim, D_optim = utils.get_optim(config, model_params)

    # Set up loss function
    if 'bce' in config['loss_fn']:
        loss_fn = nn.BCELoss().to(config['gpu'])
    else:
        raise Exception("No GAN loss function selected ... aborting")

    # Set up training function
    train_fn = train_fns.GAN_train_fn(G,
                                      D,
                                      G_optim,
                                      D_optim,
                                      loss_fn,
                                      config,
                                      G_D=None)
    return {
        'G': G,
        'G_optim': G_optim,
        'D': D,
        'D_optim': D_optim,
        'train_fn': train_fn
    }
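Here utils.get_optim(config, model_params) returns a (G_optim, D_optim) pair built from a dict of parameter iterables. A minimal sketch of that variant, assuming config holds per-model learning rates and Adam betas (the keys g_lr, d_lr, beta1 and beta2 are hypothetical):

import torch.optim as optim

def get_optim(config, model_params):
    # Hypothetical sketch: one Adam optimizer per GAN component.
    g_optim = optim.Adam(model_params['g_params'], lr=config['g_lr'],
                         betas=(config['beta1'], config['beta2']))
    d_optim = optim.Adam(model_params['d_params'], lr=config['d_lr'],
                         betas=(config['beta1'], config['beta2']))
    return g_optim, d_optim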
Example #9
def fine_tune(epoch):
    net.share.train()
    net.coarse.train()
    for i in range (args.num_superclasses):
        net.fines[i].train()

    param = list(net.share.parameters()) + list(net.coarse.parameters())
    for k in range(args.num_superclasses):
        param += list(net.fines[k].parameters())
    optimizer, lr = get_optim(param, args, mode='fineTune', epoch=epoch)

    print('\n==> fine-tune Epoch #%d, LR=%.4f' % (epoch, lr))
    for batch_idx, (inputs, targets) in enumerate(trainloader):
        if cf.use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda() # GPU settings
        optimizer.zero_grad()
        inputs, targets = Variable(inputs), Variable(targets).long()

        outputs, coarse_outputs = net(inputs, return_coarse=True)

        tloss = pred_loss(outputs, targets)
        closs = consistency_loss(coarse_outputs, t_k, weight=args.weight_consistency)
        loss = tloss + closs
        loss.backward()  # Backward Propagation
        optimizer.step()

        _, predicted = torch.max(outputs.data, 1)
        num_ins = targets.size(0)
        correct = predicted.eq(targets.data).cpu().sum()
        acc = 100.*correct.item()/num_ins

        sys.stdout.write('\r')
        sys.stdout.write('Finetune Epoch [%3d/%3d] Iter [%3d/%3d]\t\t tloss: %.4f closs: %.4f Loss: %.4f Accuracy: %.3f%%'
                         %(epoch, args.num_epochs_train, batch_idx+1, (trainloader.dataset.train_data.shape[0]//args.train_batch_size)+1,
                           tloss.item(), closs.item(), loss.item(), acc))
        sys.stdout.flush()
Example #10
def ae(model, config):
    '''
        AutoEncoder setup function
    '''
    # Get model kwargs
    ae_kwargs, config = get_ae_kwargs(config)

    # Set up model on GPU
    if config['model'] == 'ae':
        AE = model.AutoEncoder(**ae_kwargs).to(config['gpu'])
    else:
        AE = model.ConvAutoEncoder(**ae_kwargs).to(config['gpu'])

    print(AE)
    input('Press return to launch')

    # Set up model optimizer function
    model_params = {'ae_params': AE.parameters()}
    AE_optim = utils.get_optim(config, model_params)

    # Set up loss function
    if 'mse' in config['loss_fn']:
        loss_fn = nn.MSELoss().to(config['gpu'])
    elif 'bce' in config['loss_fn']:
        loss_fn = nn.BCELoss().to(config['gpu'])
    else:
        raise Exception("No AutoEncoder loss function selected!")

    # Set up training function
    if config['model'] == 'ae':
        train_fn = train_fns.AE_train_fn(AE, AE_optim, loss_fn, config)
    else:
        train_fn = train_fns.Conv_AE_train_fn(AE, AE_optim, loss_fn, config)

    # Return model, optimizer, and model training function
    return {'AE': AE, 'AE_optim': AE_optim, 'train_fn': train_fn}
Example #11
def train_test_autoencoder(newclasses, Autoencoder):
    # ================== used to train the new encoder ==================
    Autoencoder[str(newclasses)] = autoencoder(args)
    if cf.use_cuda:
        Autoencoder[str(newclasses)].cuda()
        cudnn.benchmark = True
    for epoch in range(args.num_epochs_train):
        Autoencoder[str(newclasses)].train()
        required_train_loader = get_dataLoder(args,
                                              classes=[newclasses],
                                              mode='Train',
                                              encoded=True,
                                              one_hot=True)
        param = list(Autoencoder[str(newclasses)].parameters())
        optimizer, lr = get_optim(param, args, mode='preTrain', epoch=epoch)
        print('\n==> Epoch #%d, LR=%.4f' % (epoch + 1, lr))
        for batch_idx, (inputs, targets) in enumerate(required_train_loader):
            if batch_idx >= args.num_test:
                break
            if cf.use_cuda:
                inputs = inputs.cuda()  # GPU settings
            optimizer.zero_grad()
            inputs = Variable(inputs)
            reconstructions, _ = Autoencoder[str(newclasses)](inputs)
            loss = cross_entropy(reconstructions, inputs)
            loss.backward()  # Backward Propagation
            optimizer.step()  # Optimizer update
            sys.stdout.write('\r')
            sys.stdout.write(
                'Train autoencoder:' + str(newclasses) +
                ' with Epoch [%3d/%3d] Iter [%3d]\t\t Loss: %.4f' %
                (epoch + 1, args.num_epochs_train, batch_idx + 1, loss.item()))
            sys.stdout.flush()
    # =============== used to classify it and put it in a proper superclass ==============
    if Autoencoder:
        Loss = {}
        Rel = {}
        print('\ntesting the new data in previous autoencoders')
        for dict in Autoencoder:
            Loss[dict] = 0
            required_valid_loader = get_dataLoder(args,
                                                  classes=[int(dict)],
                                                  mode='Valid',
                                                  encoded=True,
                                                  one_hot=True)
            for batch_idx, (inputs,
                            targets) in enumerate(required_valid_loader):
                if batch_idx >= args.num_test:
                    break
                if cf.use_cuda:
                    inputs = inputs.cuda()  # GPU settings
                inputs = Variable(inputs)
                reconstructions, _ = Autoencoder[dict](inputs)
                loss = cross_entropy(reconstructions, inputs)
                Loss[dict] += loss.data.cpu().numpy(
                ) if cf.use_cuda else loss.data.numpy()
        print('\nAutoencoder:' + str(newclasses) +
              ' has been deleted and waits to be updated every ten classes')
        Autoencoder.pop(
            str(newclasses), '\nthe class:' + str(newclasses) +
            ' was not deleted as the key does not exist')
        highest = 0
        test_result = ''
        for dict in Loss:
            Rel[dict] = 1 - abs(
                (Loss[dict] - Loss[str(newclasses)]) / Loss[str(newclasses)])
            if Rel[dict] >= highest and Rel[
                    dict] >= args.rel_th and dict != str(newclasses):
                highest = Rel[dict]
                test_result = dict
                print('\nnewclass:' + str(newclasses) +
                      ' is added to the superclass with class:' + dict)
        print('\nClass rel:', Rel, ' and Loss:', Loss)
        return Autoencoder, test_result
    else:
        return Autoencoder, None
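The superclass test above scores every existing autoencoder by how close its reconstruction loss is to the new class's loss, Rel = 1 - |(Loss_d - Loss_new) / Loss_new|, and assigns the new class to the best scorer that clears args.rel_th. A worked check with made-up losses and a threshold of 0.8:

Loss = {'3': 0.52, '7': 0.41, 'new': 0.40}
rel_th = 0.8
Rel = {k: 1 - abs((v - Loss['new']) / Loss['new'])
       for k, v in Loss.items() if k != 'new'}
# Rel is roughly {'3': 0.70, '7': 0.975}; only '7' clears rel_th,
# so the new class would join the superclass containing class 7.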
Example #12
    batch_size = 64
    latent_dim = 100
    d_updates = 5

    dataloader = utils.get_dataloader(batch_size)
    device = utils.get_device()
    step_per_epoch = np.ceil(len(dataloader.dataset) / batch_size)
    sample_dir = './samples'
    checkpoint_dir = './checkpoints'

    utils.makedirs(sample_dir, checkpoint_dir)

    G = Generator(latent_dim=latent_dim).to(device)
    D = Discriminator().to(device)

    g_optim = utils.get_optim(G, 0.00005)
    d_optim = utils.get_optim(D, 0.00005)

    g_log = []
    d_log = []

    criterion = nn.BCELoss()

    fix_z = torch.randn(batch_size, latent_dim).to(device)
    for epoch_i in range(1, epochs + 1):
        for step_i, (real_img, _) in enumerate(dataloader):

            real_labels = torch.ones(batch_size).to(device)
            fake_labels = torch.zeros(batch_size).to(device)

            # Train D
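This snippet passes the whole model plus a scalar learning rate, utils.get_optim(G, 0.00005). A minimal sketch of that flavour (Adam with the usual GAN betas is an assumption, not taken from the source):

import torch.optim as optim

def get_optim(model, lr):
    # Hypothetical sketch: a single optimizer over all of the model's parameters.
    return optim.Adam(model.parameters(), lr=lr, betas=(0.5, 0.999))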
Example #13
def train(model, triples, entities, un_ents, un_rels, test_pairs):
    logging.info("---------------Start Training---------------")

    ht_1, ht_2 = get_r_hts(triples, un_rels)
    rel_seeds = relation_seeds({}, ht_1, ht_2, un_rels)

    current_lr = config.learning_rate
    optimizer = get_optim(model, current_lr)
    if config.init_checkpoint:
        logging.info("Loading checkpoint...")
        checkpoint = torch.load(os.path.join(config.save_path, "checkpoint"))
        init_step = checkpoint["step"] + 1
        model.load_state_dict(checkpoint["model_state_dict"])
        if config.use_old_optimizer:
            current_lr = checkpoint["current_lr"]
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    else:
        init_step = 1

    training_logs = []
    train_iterator = train_data_iterator(entities,
                                         new_triples(triples, rel_seeds, {}))
    # Training Loop
    for step in range(init_step, config.max_step):
        log = train_step(model, optimizer, next(train_iterator))
        training_logs.append(log)

        # log
        if step % config.log_step == 0:
            metrics = {}
            for metric in training_logs[0].keys():
                metrics[metric] = sum([log[metric] for log in training_logs
                                       ]) / len(training_logs)
            log_metrics("Training average", step, metrics)
            training_logs.clear()

        # warm up
        if step % config.warm_up_step == 0:
            current_lr *= 0.1
            logging.info("Change learning_rate to %f at step %d" %
                         (current_lr, step))
            optimizer = get_optim(model, current_lr)

        if step % config.update_step == 0:
            logging.info("Align entities and relations, swap parameters")
            seeds, align_e_1, align_e_2 = entity_seeds(model, un_ents)
            rel_seeds = relation_seeds(seeds, ht_1, ht_2, un_rels)
            new_entities = (entities[0] + align_e_2, entities[1] + align_e_1)
            train_iterator = train_data_iterator(
                new_entities, new_triples(triples, rel_seeds, seeds))
            save_variable_list = {
                "step": step,
                "current_lr": current_lr,
            }
            save_model(model, optimizer, save_variable_list)

    logging.info("---------------Test on test dataset---------------")
    metrics = test_step(model, test_pairs, un_ents)
    log_metrics("Test", config.max_step, metrics)

    logging.info("---------------Taining End---------------")
Example #14
def train(model, triples, ent_num):
    logging.info("Start Training...")
    logging.info("batch_size = %d" % config.batch_size)
    logging.info("dim = %d" % config.ent_dim)
    logging.info("gamma = %f" % config.gamma)

    current_lr = config.learning_rate
    train_triples, valid_triples, test_triples = triples
    all_true_triples = train_triples + valid_triples + test_triples
    rtp = rel_type(train_triples)

    optimizer = get_optim("Adam", model, current_lr)
    train_iterator = train_data_iterator(train_triples, ent_num)

    if config.init_checkpoint:
        logging.info("Loading checkpoint...")
        checkpoint = torch.load(os.path.join(config.save_path, "checkpoint"))
        init_step = checkpoint["step"] + 1
        model.load_state_dict(checkpoint["model_state_dict"])
        if config.use_old_optimizer:
            current_lr = checkpoint["current_lr"]
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    else:
        init_step = 1

    max_hit1 = 0.0
    max_mrr = 0.0
    training_logs = []
    # Training Loop
    for step in range(init_step, config.max_step):
        log = train_step(model, optimizer, next(train_iterator))
        training_logs.append(log)

        # log
        if step % config.log_step == 0:
            metrics = {}
            for metric in training_logs[0].keys():
                metrics[metric] = sum([log[metric] for log in training_logs
                                       ]) / len(training_logs)
            log_metrics("Training average", step, metrics)
            training_logs = []

        # valid
        if step % config.valid_step == 0:
            logging.info(
                "---------------Evaluating on Valid Dataset---------------")
            metrics = test_step(model, valid_triples, all_true_triples,
                                ent_num, rtp)
            metrics, metrics1, metrics2, metrics3, metrics4, metrics5, metrics6, metrics7, metrics8 = metrics
            logging.info("----------------Overall Results----------------")
            log_metrics("Valid", step, metrics)
            logging.info("-----------Prediction Head... 1-1 -------------")
            log_metrics("Valid", step, metrics1)
            logging.info("-----------Prediction Head... 1-M -------------")
            log_metrics("Valid", step, metrics2)
            logging.info("-----------Prediction Head... M-1 -------------")
            log_metrics("Valid", step, metrics3)
            logging.info("-----------Prediction Head... M-M -------------")
            log_metrics("Valid", step, metrics4)
            logging.info("-----------Prediction Tail... 1-1 -------------")
            log_metrics("Valid", step, metrics5)
            logging.info("-----------Prediction Tail... 1-M -------------")
            log_metrics("Valid", step, metrics6)
            logging.info("-----------Prediction Tail... M-1 -------------")
            log_metrics("Valid", step, metrics7)
            logging.info("-----------Prediction Tail... M-M -------------")
            log_metrics("Valid", step, metrics8)
            if metrics["HITS@1"] >= max_hit1 or metrics["MRR"] >= max_mrr:
                if metrics["HITS@1"] > max_hit1:
                    max_hit1 = metrics["HITS@1"]
                if metrics["MRR"] > max_mrr:
                    max_mrr = metrics["MRR"]
                save_variable_list = {
                    "step": step,
                    "current_lr": current_lr,
                }
                save_model(model, optimizer, save_variable_list)
            elif current_lr > 0.0000011:
                current_lr *= 0.1
                logging.info("Change learning_rate to %f at step %d" %
                             (current_lr, step))
                optimizer = get_optim("Adam", model, current_lr)
            else:
                logging.info(
                    "-------------------Training End-------------------")
                break
    # best state
    checkpoint = torch.load(os.path.join(config.save_path, "checkpoint"))
    model.load_state_dict(checkpoint["model_state_dict"])
    step = checkpoint["step"]
    logging.info(
        "-----------------Evaluating on Test Dataset-------------------")
    metrics = test_step(model, test_triples, all_true_triples, ent_num, rtp)
    metrics, metrics1, metrics2, metrics3, metrics4, metrics5, metrics6, metrics7, metrics8 = metrics
    logging.info("----------------Overall Results----------------")
    log_metrics("Test", step, metrics)
    logging.info("-----------Prediction Head... 1-1 -------------")
    log_metrics("Test", step, metrics1)
    logging.info("-----------Prediction Head... 1-M -------------")
    log_metrics("Test", step, metrics2)
    logging.info("-----------Prediction Head... M-1 -------------")
    log_metrics("Test", step, metrics3)
    logging.info("-----------Prediction Head... M-M -------------")
    log_metrics("Test", step, metrics4)
    logging.info("-----------Prediction Tail... 1-1 -------------")
    log_metrics("Test", step, metrics5)
    logging.info("-----------Prediction Tail... 1-M -------------")
    log_metrics("Test", step, metrics6)
    logging.info("-----------Prediction Tail... M-1 -------------")
    log_metrics("Test", step, metrics7)
    logging.info("-----------Prediction Tail... M-M -------------")
    log_metrics("Test", step, metrics8)
Example #15
    if params.early_stopping['monitor_metric'] is False:
        logging.info("Early stopping disabled.")
    else:
        logging.info("Early stopping enabled.")

    # enable gradient computation for all layers (the fc head is set again explicitly below)
    for layer in model.parameters():
        layer.requires_grad = True
    for i in model.fc.parameters():
        i.requires_grad = True

    # split the model into different parameter groups and assign an lr to each
    optimizer = utils.get_optim(
        model=model,
        idx_to_split_model=[5, 8],
        optimizer=torch.optim.Adam,
        lrs=params.learning_rate,
        wd=params.wd,
    )

    # fetch loss function and metrics
    loss_fn = torch.nn.CrossEntropyLoss()

    # select scheduler
    if params.scheduler == "one_cycle":
        scheduler = utils.OneCycle(
            params.num_epochs,
            optimizer,
            div_factor=params.div_factor,
            pct_start=params.pct_start,
            dl_len=len(train_dl),
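In this example get_optim builds parameter groups with different learning rates: the model's top-level children are split at the given indices (idx_to_split_model=[5, 8] gives three groups) and each group gets its own entry from lrs, a common discriminative-fine-tuning setup. A minimal sketch under those assumptions (lrs is assumed to be a list with one learning rate per group):

def get_optim(model, idx_to_split_model, optimizer, lrs, wd):
    # Hypothetical sketch: split the model's top-level children into groups at
    # the given indices and give each group its own learning rate.
    children = list(model.children())
    bounds = [0] + list(idx_to_split_model) + [len(children)]
    groups = []
    for (start, end), lr in zip(zip(bounds[:-1], bounds[1:]), lrs):
        params = [p for child in children[start:end] for p in child.parameters()]
        groups.append({'params': params, 'lr': lr})
    return optimizer(groups, weight_decay=wd)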
Example #16
def search(out_dir, chkpt_path, w_train_loader, a_train_loader, model, arch,
           writer, logger, device, config):
    valid_loader = a_train_loader

    w_optim = utils.get_optim(model.weights(), config.w_optim)
    a_optim = utils.get_optim(model.alphas(), config.a_optim)

    init_epoch = -1

    if chkpt_path is not None:
        logger.info("Resuming from checkpoint: %s" % chkpt_path)
        checkpoint = torch.load(chkpt_path)
        model.load_state_dict(checkpoint['model'])
        NASModule.nasmod_load_state_dict(checkpoint['arch'])
        w_optim.load_state_dict(checkpoint['w_optim'])
        a_optim.load_state_dict(checkpoint['a_optim'])
        # the lr_scheduler state is restored below, after the scheduler is created
        init_epoch = checkpoint['epoch']
    else:
        logger.info("Starting new training run")

    architect = arch(config, model)

    # warmup training loop
    logger.info('begin warmup training')
    try:
        if config.warmup_epochs > 0:
            warmup_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                w_optim, config.warmup_epochs, eta_min=config.w_optim.lr_min)
            last_epoch = 0
        else:
            last_epoch = -1

        tot_epochs = config.warmup_epochs
        for epoch in itertools.count(init_epoch + 1):
            if epoch == tot_epochs: break
            lr = warmup_lr_scheduler.get_lr()[0]
            # training
            train(w_train_loader, None, model, writer, logger, architect,
                  w_optim, a_optim, lr, epoch, tot_epochs, device, config)
            # validation
            cur_step = (epoch + 1) * len(w_train_loader)
            top1 = validate(valid_loader, model, writer, logger, epoch,
                            tot_epochs, cur_step, device, config)
            warmup_lr_scheduler.step()
            print("")
    except KeyboardInterrupt:
        print('skipped')

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        w_optim,
        config.epochs,
        eta_min=config.w_optim.lr_min,
        last_epoch=last_epoch)
    if chkpt_path is not None:
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])

    save_checkpoint(out_dir, model, w_optim, a_optim, lr_scheduler, init_epoch,
                    logger)
    save_genotype(out_dir, model.genotype(), init_epoch, logger)

    # training loop
    logger.info('begin w/a training')
    best_top1 = 0.
    tot_epochs = config.epochs
    for epoch in itertools.count(init_epoch + 1):
        if epoch == tot_epochs: break
        lr = lr_scheduler.get_lr()[0]
        model.print_alphas(logger)
        # training
        train(w_train_loader, a_train_loader, model, writer, logger, architect,
              w_optim, a_optim, lr, epoch, tot_epochs, device, config)
        # validation
        cur_step = (epoch + 1) * len(w_train_loader)
        top1 = validate(valid_loader, model, writer, logger, epoch, tot_epochs,
                        cur_step, device, config)
        # genotype
        genotype = model.genotype()
        save_genotype(out_dir, genotype, epoch, logger)
        # genotype as image
        if config.plot:
            for i, dag in enumerate(model.dags()):
                plot_path = os.path.join(config.plot_path,
                                         "EP{:02d}".format(epoch + 1))
                caption = "Epoch {} - DAG {}".format(epoch + 1, i)
                plot(genotype.dag[i], dag, plot_path + "-dag_{}".format(i),
                     caption)
        if best_top1 < top1:
            best_top1 = top1
            best_genotype = genotype
        if config.save_freq != 0 and epoch % config.save_freq == 0:
            save_checkpoint(out_dir, model, w_optim, a_optim, lr_scheduler,
                            epoch, logger)
        lr_scheduler.step()
        print("")

    logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
    logger.info("Best Genotype = {}".format(best_genotype))
    tprof.stat_acc('model_' + NASModule.get_device()[0])
    gt.to_file(best_genotype, os.path.join(out_dir, 'best.gt'))
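The resume blocks in this example and in Example #2 read the keys 'model', 'arch', 'w_optim', 'a_optim', 'lr_scheduler' and 'epoch' from the checkpoint, so save_checkpoint presumably writes a dict with those keys. A minimal sketch, assuming NASModule exposes a nasmod_state_dict() counterpart to the nasmod_load_state_dict() used above (that name is a guess) and that a_optim may be None, as in Example #2:

import os
import torch

def save_checkpoint(out_dir, model, w_optim, a_optim, lr_scheduler, epoch, logger):
    # Hypothetical sketch mirroring the keys the resume code expects.
    state = {
        'model': model.state_dict(),
        'arch': NASModule.nasmod_state_dict(),  # assumed counterpart API
        'w_optim': w_optim.state_dict(),
        'a_optim': None if a_optim is None else a_optim.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'epoch': epoch,
    }
    path = os.path.join(out_dir, 'chkpt_{:03d}.pt'.format(epoch))
    torch.save(state, path)
    logger.info('Saved checkpoint to %s' % path)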
Example #17
def pretrain_clustering(epoch, mode, cluster_result=None):
    net.share.train()
    net.croase.train()

    train_loss = 0
    optimizer_share, lr = get_optim(net.share,
                                    args,
                                    mode='preTrain',
                                    epoch=epoch)
    optimizer_croase, lr = get_optim(net.croase,
                                     args,
                                     mode='preTrain',
                                     epoch=epoch)

    if mode == 'clustering':
        Data = enumerate(trainloader)
        print('\ntrain data-loader activated')
    else:
        print(
            '---------------- Warning! no mode is activated ----------------\n'
        )
    print('=> pre-train Epoch #%d, LR=%.4f' % (epoch, lr))
    for batch_idx, (inputs, targets) in Data:
        if batch_idx >= args.num_test:
            break
        if cf.use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()  # GPU settings
        optimizer_share.zero_grad()
        optimizer_croase.zero_grad()
        inputs, targets = Variable(inputs), Variable(targets)

        outputs = net.croase.independent(
            net.share.encoder(inputs))  # Forward Propagation
        if batch_idx == 0:
            total_outputs = outputs
            total_targets = targets
        else:
            total_outputs = torch.cat((total_outputs, outputs), 0)
            total_targets = torch.cat((total_targets, targets), 0)

        loss = pred_loss(outputs, targets)

        loss.backward()  # Backward Propagation
        optimizer_share.step()  # Optimizer update
        optimizer_croase.step()  # Optimizer update

        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        _, targets = torch.max(targets.data, 1)
        num_ins = targets.size(0)
        correct = predicted.eq(targets.data).cpu().sum()

        sys.stdout.write('\r')
        sys.stdout.write(
            'Pre-train Epoch [%3d/%3d] Iter [%3d/%3d]\t\t Loss: %.4f Accuracy: %.3f%%'
            %
            (epoch, args.num_epochs_pretrain, batch_idx + 1,
             (trainloader.dataset.train_data.shape[0] // args.train_batch_size)
             + 1, loss.item(), 100. * correct.item() / num_ins))
        sys.stdout.flush()

    print('\n=> valid epoch beginning for clustering')
    if mode == 'clustering':
        clustering_data = enumerate(validloader)
        print('valid data-loader activated')
        for batch_idx, (inputs, targets) in clustering_data:
            if batch_idx >= args.num_test:
                break
            if cf.use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda()  # GPU settings
            optimizer_share.zero_grad()
            inputs, targets = Variable(inputs), Variable(targets)

            outputs = net.croase.independent(
                net.share.encoder(inputs))  # Forward Propagation
            if batch_idx == 0:
                total_outputs = outputs
                total_targets = targets
            else:
                total_outputs = torch.cat((total_outputs, outputs), 0)
                total_targets = torch.cat((total_targets, targets), 0)
            loss = pred_loss(outputs, targets)

            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            _, targets = torch.max(targets.data, 1)
            num_ins = targets.size(0)
            correct = predicted.eq(targets.data).cpu().sum()
            acc = 100. * correct.item() / num_ins
            sys.stdout.write('\r')
            sys.stdout.write(
                'Valid Epoch [%3d/%3d] Iter [%3d/%3d]\t\t Accuracy: %.3f%%'
                % (epoch, args.num_epochs_pretrain, batch_idx + 1,
                   (trainloader.dataset.train_data.shape[0] //
                    args.train_batch_size) + 1, acc))
            sys.stdout.flush()

        print('\nSaving model...\t\t\tTop1 = %.2f%%' % (acc))
        share_params = net.share.state_dict()
        croase_params = net.croase.state_dict()
        save_point = cf.model_dir + args.dataset
        if not os.path.isdir(save_point):
            os.mkdir(save_point)
        torch.save(share_params, save_point + '/share_params.pkl')
        torch.save(croase_params, save_point + '/croase_params.pkl')
    return total_outputs, total_targets
Example #18
def learn_and_clustering(args, L, epoch, classes, test = True, cluster = True, save = False):
    required_train_loader = get_dataLoder(args, classes= classes, mode='Train', one_hot=True)
    # L = leaf(args,np.shape( classes)[0])
    param = list(L.parameters())
    optimizer, lr = get_optim(param, args, mode='preTrain', epoch=epoch)
    print('\n==> Epoch %d, LR=%.4f' % (epoch+1, lr))
    best_acc=0
    required_data = []
    required_targets = []
    for batch_idx, (inputs, targets) in enumerate(required_train_loader):
        if batch_idx>=args.num_test:
            break
        # targets = targets[:, sorted(list({}.fromkeys((torch.max(targets.data, 1)[1]).numpy()).keys()))]
        if cf.use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda() # GPU setting
        optimizer.zero_grad()
        inputs, targets = Variable(inputs), Variable(targets).long()
        outputs = L(inputs) # Forward Propagation
        loss = pred_loss(outputs,targets)
        loss.backward()  # Backward Propagation
        optimizer.step() # Optimizer update

        _, predicted = torch.max(outputs.data, 1)
        num_ins = targets.size(0)
        correct = predicted.eq((torch.max(targets.data, 1)[1])).cpu().sum()
        acc=100.*correct.item()/num_ins
        sys.stdout.write('\r')
        sys.stdout.write('Train Epoch [%3d/%3d] Iter [%3d/%3d]\t\t Loss: %.4f Accuracy: %.3f%%'
                         %(epoch+1, args.num_epochs_train, batch_idx+1, (pretrainloader.dataset.train_data.shape[0]//args.pretrain_batch_size)+1,
                           loss.item(), acc))
        sys.stdout.flush()
        #========================= saving the model ============================
        if epoch+1 == args.num_epochs_train and acc>best_acc and save:
            print('\nSaving the best leaf model...\t\t\tTop1 = %.2f%%' % (acc))
            save_point = cf.var_dir + args.dataset
            if not os.path.isdir(save_point):
                os.mkdir(save_point)
            torch.save(L.state_dict(), save_point + '/L0.pkl')
            best_acc=acc
    #============================ valid training result ==================================
    if epoch+1 == args.num_epochs_train and test:
        required_valid_loader = get_dataLoder(args, classes= classes, mode='Valid', one_hot = True)
        num_ins = 0
        correct = 0
        for batch_idx, (inputs, targets) in enumerate(required_valid_loader):
            if cf.use_cuda:
                inputs, targets = inputs.cuda(), targets.cuda() # GPU settings
            inputs, targets = Variable(inputs), Variable(targets).long()
            outputs = L(inputs)

            #======================= prepare data for clustering ============================
            if cluster:
                batch_required_data = outputs
                batch_required_targets = targets
                batch_required_data = batch_required_data.data.cpu().numpy() if cf.use_cuda else batch_required_data.data.numpy()
                batch_required_targets = batch_required_targets.data.cpu().numpy() if cf.use_cuda else batch_required_targets.data.numpy()
                required_data = stack_or_create(required_data, batch_required_data, axis=0)
                required_targets = stack_or_create(required_targets, batch_required_targets, axis=0)
            targets = torch.argmax(targets,1)
            _, predicted = torch.max(outputs.data, 1)
            num_ins += targets.size(0)
            correct += predicted.eq(targets.data).cpu().sum().item()
        #============================ clustering ==================================
        if cluster:
            print('\n==> Doing the spectral clustering')
            required_data = np.argmax(required_data, 1)
            required_targets = np.argmax(required_targets,1)
            F = function.confusion_matrix(required_data, required_targets)
            D = (1/2)*((np.identity(np.shape(classes)[0])-F)+np.transpose(np.identity(np.shape(classes)[0])-F))
            cluster_result = function.spectral_clustering(D, K=args.num_superclasses, gamma=10)
        acc = 100.*correct/num_ins
        print("\nValidation Epoch %d\t\tAccuracy: %.2f%%" % (epoch+1, acc))
        if cluster:
            return L, cluster_result, acc
        else:
            return L, None, acc
    else:
        return L, None, None
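The clustering step relies on function.spectral_clustering(D, K, gamma), which is not shown on this page. A minimal sketch of such a helper using scikit-learn, assuming D is a symmetric distance-like matrix that is first turned into an affinity with an RBF kernel (the kernel choice is an assumption):

import numpy as np
from sklearn.cluster import SpectralClustering

def spectral_clustering(D, K, gamma=10):
    # Hypothetical sketch: distance-like matrix -> RBF affinity -> spectral labels.
    affinity = np.exp(-gamma * D ** 2)
    return SpectralClustering(n_clusters=K, affinity='precomputed').fit_predict(affinity)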
Example #19
def train(model, triples, ent_num):
    logging.info("Start Training...")
    logging.info("batch_size = %d" % config.batch_size)
    logging.info("dim = %d" % config.ent_dim)
    logging.info("gamma = %f" % config.gamma)

    current_lr = config.learning_rate
    train_triples, valid_triples, test_triples, symmetry_test, inversion_test, composition_test, others_test = triples
    all_true_triples = train_triples + valid_triples + test_triples
    r_tp = rel_type(train_triples)

    optimizer = get_optim("Adam", model, current_lr)

    if config.init_checkpoint:
        logging.info("Loading checkpoint...")
        checkpoint = torch.load(os.path.join(config.save_path, "checkpoint"),
                                map_location=torch.device("cuda:0"))
        init_step = checkpoint["step"] + 1
        model.load_state_dict(checkpoint["model_state_dict"])
        if config.use_old_optimizer:
            current_lr = checkpoint["current_lr"]
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    else:
        init_step = 1

    true_all_heads, true_all_tails = get_true_ents(all_true_triples)
    train_iterator = train_data_iterator(train_triples, ent_num)
    test_data_list = test_data_sets(valid_triples, true_all_heads,
                                    true_all_tails, ent_num, r_tp)

    max_mrr = 0.0
    training_logs = []
    modes = ["Prediction Head", "Prediction Tail"]
    rtps = ["1-1", "1-M", "M-1", "M-M"]
    # Training Loop
    for step in range(init_step, config.max_step + 1):
        log = train_step(model, optimizer, next(train_iterator))
        training_logs.append(log)

        # log
        if step % config.log_step == 0:
            metrics = {}
            for metric in training_logs[0].keys():
                metrics[metric] = sum([log[metric] for log in training_logs
                                       ]) / len(training_logs)
            log_metrics("Training", step, metrics)
            training_logs.clear()

        # valid
        if step % config.valid_step == 0:
            logging.info("-" * 10 + "Evaluating on Valid Dataset" + "-" * 10)
            metrics = test_step(model, test_data_list, True)
            log_metrics("Valid", step, metrics[0])
            cnt_mode_rtp = 1
            for mode in modes:
                for rtp in rtps:
                    logging.info("-" * 10 + mode + "..." + rtp + "-" * 10)
                    log_metrics("Valid", step, metrics[cnt_mode_rtp])
                    cnt_mode_rtp += 1
            if metrics[0]["MRR"] >= max_mrr:
                max_mrr = metrics[0]["MRR"]
                save_variable_list = {
                    "step": step,
                    "current_lr": current_lr,
                }
                save_model(model, optimizer, save_variable_list)
            if step / config.max_step in [0.2, 0.5, 0.8]:
                current_lr *= 0.1
                logging.info("Change learning_rate to %f at step %d" %
                             (current_lr, step))
                optimizer = get_optim("Adam", model, current_lr)

    # load best state
    checkpoint = torch.load(os.path.join(config.save_path, "checkpoint"))
    model.load_state_dict(checkpoint["model_state_dict"])
    step = checkpoint["step"]

    # relation patterns
    test_datasets = [
        symmetry_test, inversion_test, composition_test, others_test
    ]
    test_datasets_str = ["Symmetry", "Inversion", "Composition", "Other"]
    for i in range(len(test_datasets)):
        dataset = test_datasets[i]
        dataset_str = test_datasets_str[i]
        if len(dataset) == 0:
            continue
        test_data_list = test_data_sets(dataset, true_all_heads,
                                        true_all_tails, ent_num, r_tp)
        logging.info("-" * 10 + "Evaluating on " + dataset_str + " Dataset" +
                     "-" * 10)
        metrics = test_step(model, test_data_list)
        log_metrics("Valid", step, metrics)

    # finally test
    test_data_list = test_data_sets(test_triples, true_all_heads,
                                    true_all_tails, ent_num, r_tp)
    logging.info("----------Evaluating on Test Dataset----------")
    metrics = test_step(model, test_data_list, True)
    log_metrics("Test", step, metrics[0])
    cnt_mode_rtp = 1
    for mode in modes:
        for rtp in rtps:
            logging.info("-" * 10 + mode + "..." + rtp + "-" * 10)
            log_metrics("Test", step, metrics[cnt_mode_rtp])
            cnt_mode_rtp += 1
Example #20
def fine_tune(epoch, cluster_result=None, u_kj=None):
    save_point = cf.model_dir + args.dataset

    net.share.train()
    net.croase.train()
    for i in range(args.num_superclass):
        net.fines[i].train()

    train_loss = 0
    optimizer_share, lr = get_optim(net.share,
                                    args,
                                    mode='preTrain',
                                    epoch=epoch)
    optimizer_croase, lr = get_optim(net.croase,
                                     args,
                                     mode='preTrain',
                                     epoch=epoch)
    optimizer_fine = {}
    for i in range(args.num_superclass):
        optimizer_fine[i], lr = get_optim(net.fines[i],
                                          args,
                                          mode='preTrain',
                                          epoch=epoch)

    if epoch == 1:
        print('\nprevious model activated')
        net.load_state_dict(torch.load(save_point + '/over_all_model.pkl'))
        for i in range(args.num_superclass):
            net.fines[i].load_state_dict(
                torch.load(save_point + '/fine' + str(i) + '.pkl'))
    print('\ntrain data-loader activated')
    Data = enumerate(trainloader)
    for batch_idx, (inputs, targets) in Data:
        if batch_idx >= args.num_test:
            break
        if cf.use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()  # GPU settings
        optimizer_share.zero_grad()
        optimizer_croase.zero_grad()
        for i in range(args.num_superclass):
            optimizer_fine[i].zero_grad()
        inputs, targets = Variable(inputs), Variable(targets)
        share = net.share.encoder(inputs)
        outputs = net.croase.independent(share)  # Forward Propagation
        fine_out = {}
        for i in range(args.num_superclass):
            fine_out[i] = net.fines[i].independent(
                share)  # output of each fine layers
        # ==================== prepare B_o_ik =======================
        B_o_ik = np.zeros((np.shape(outputs)[0], args.num_superclass))
        for i in range(np.shape(outputs)[0]):
            for j in range(args.num_fine_classes):
                for k in range(args.num_superclass):
                    if cluster_result[j] == k or u_kj[j, k] >= u_t:
                        B_o_ik[i, k] += outputs[i, j]
        # ================== prepare fine_result ====================
        fine_result = torch.zeros(
            np.shape(fine_out[1])[0], args.num_fine_classes)
        for i in range(np.shape(fine_out[1])[0]):
            result_upper = torch.zeros(
                np.shape(fine_out[1])[0], args.num_fine_classes)
            result_lower = torch.zeros(
                np.shape(fine_out[1])[0], args.num_fine_classes)
            for k in range(args.num_superclass):
                result_upper[i, :] += B_o_ik[i, k] * fine_out[k][i, :]
                result_lower[i, :] += fine_out[k][i, :]
            for j in range(args.num_fine_classes):
                fine_result[i, j] = result_upper[i, j] / result_lower[i, j]
        # ====================== prepare t_k ========================
        t_k = torch.zeros(args.num_superclass)
        total = 0
        for k in range(args.num_superclass):
            for i in range(np.shape(fine_out[1])[0]):
                for j in range(args.num_fine_classes):
                    if cluster_result[j] == k or u_kj[j, k] >= u_t:
                        t_k[k] += 1
        t_k = t_k / torch.sum(t_k)
        # =================== finish preparation ====================
        loss = fine_tuning_loss(fine_result,
                                torch.max(targets, 1)[1], t_k, B_o_ik, args)
        loss.backward()  # Backward Propagation
        for i in range(args.num_superclass):
            optimizer_fine[i].step()  # Optimizer update
        train_loss += loss.item()
        _, predicted = torch.max(fine_result.data, 1)
        _, targets = torch.max(targets.data, 1)
        num_ins = targets.size(0)
        correct = predicted.eq(targets.data).cpu().sum()
        acc = 100. * correct.item() / num_ins
        sys.stdout.write('\r')
        sys.stdout.write(
            'Pre-train Epoch [%3d/%3d] Iter [%3d/%3d]\t\t Loss: %.4f Accuracy: %.3f%%'
            % (epoch, args.num_epochs_pretrain, batch_idx + 1,
               (trainloader.dataset.train_data.shape[0] //
                args.train_batch_size) + 1, loss.item(), acc))
        sys.stdout.flush()
    return acc
Example #21
def pretrain_fine(epoch, cluster_result=None, u_kj=None):
    save_point = cf.model_dir + args.dataset
    net.share.train()
    net.croase.train()
    for i in range(args.num_superclass):
        net.fines[i].train()

    train_loss = 0
    optimizer_share, lr = get_optim(net.share,
                                    args,
                                    mode='preTrain',
                                    epoch=epoch)
    optimizer_croase, lr = get_optim(net.croase,
                                     args,
                                     mode='preTrain',
                                     epoch=epoch)
    optimizer_fine = {}
    for k in range(args.num_superclass):
        for para in list(net.fines[k].parameters())[:-9]:
            para.requires_grad = False
        optimizer_fine[k], lr = get_optim(net.fines[k],
                                          args,
                                          mode='preTrain',
                                          epoch=epoch)
    if epoch == 1:
        print('\nprevious model activated')
        net.share.load_state_dict(torch.load(save_point + '/share_params.pkl'))
        net.croase.load_state_dict(
            torch.load(save_point + '/croase_params.pkl'))
        for i in range(args.num_superclass):
            net.fines[i].load_state_dict(
                torch.load(save_point + '/croase_params.pkl'))

    print('\ntrain data-loader activated')
    Data = enumerate(trainloader)
    for batch_idx, (inputs, targets) in Data:
        if batch_idx >= args.num_test:
            break
        if cf.use_cuda:
            inputs, targets = inputs.cuda(), targets.cuda()  # GPU settings
        optimizer_share.zero_grad()
        optimizer_croase.zero_grad()
        for i in range(args.num_superclass):
            optimizer_fine[i].zero_grad()
        inputs, targets = Variable(inputs), Variable(targets)
        share = net.share.encoder(inputs)
        outputs = net.croase.independent(share)  # Forward Propagation
        fine_out = {}
        fine_target = {}
        fine_result = {}
        # ==================== divide the fine result =====================
        for k in range(args.num_superclass):
            fine_out[k] = []
            fine_target[k] = []
            fine_result[k] = net.fines[k].independent(share)
        for i in range(np.shape(targets)[0]):
            for j in range(args.num_fine_classes):
                if j == torch.max(targets, 1)[1][i]:
                    for k in range(args.num_superclass):
                        if cluster_result[j] == k or u_kj[j, k] >= u_t:
                            if np.shape(fine_out[k])[0] == 0:
                                fine_out[k] = torch.reshape(
                                    fine_result[k][i, :], [1, 10])
                                fine_target[k] = torch.reshape(
                                    targets[i, :], [1, 10])
                            else:
                                fine_out[k] = torch.cat(
                                    (fine_out[k],
                                     torch.reshape(fine_result[k][i, :],
                                                   [1, 10])), 0)
                                fine_target[k] = torch.cat(
                                    (fine_target[k],
                                     torch.reshape(targets[i, :], [1, 10])), 0)

        fine_loss = {}
        for k in range(args.num_superclass):
            fine_loss[k] = pred_loss(fine_out[k], fine_target[k])
            if k == 0:
                loss = fine_loss[k]
            else:
                loss += fine_loss[k]
        loss.backward()  # Backward Propagation
        for k in range(args.num_superclass):
            optimizer_fine[k].step()  # Optimizer update
        train_loss += loss.item()
        for k in range(args.num_superclass):
            if k == 0:
                predicted = torch.max(fine_out[k].data, 1)[1]
                targets = torch.max(fine_target[k].data, 1)[1]
            else:
                predicted = torch.cat(
                    (predicted, torch.max(fine_out[k].data, 1)[1]), 0)
                targets = torch.cat(
                    (targets, torch.max(fine_target[k].data, 1)[1]), 0)
        num_ins = targets.size(0)
        correct = predicted.eq(targets.data).cpu().sum()
        acc = 100. * correct.item() / num_ins
        sys.stdout.write('\r')
        sys.stdout.write(
            'Pre-train Epoch [%3d/%3d] Iter [%3d/%3d]\t\t Loss: %.4f Accuracy: %.3f%%'
            % (epoch, args.num_epochs_pretrain, batch_idx + 1,
               (trainloader.dataset.train_data.shape[0] //
                args.train_batch_size) + 1, loss.item(), acc))
        sys.stdout.flush()
    return acc