Example #1
def train():
    train_loader = generator.Generator(args.dataset_root,
                                       args,
                                       partition='train',
                                       dataset=args.dataset)
    logger.info('Batch size: ' + str(args.batch_size))

    # Try to load models
    enc_nn = models.load_model('enc_nn', args)
    metric_nn = models.load_model('metric_nn', args)

    if enc_nn is None or metric_nn is None:
        enc_nn, metric_nn = models.create_models(args=args)
    softmax_module = models.SoftmaxModule()

    if args.cuda:
        enc_nn.cuda()
        metric_nn.cuda()

    logger.info(str(enc_nn))
    logger.info(str(metric_nn))

    weight_decay = 0
    if args.dataset == 'mini_imagenet':
        weight_decay = 1e-6
        logger.info('Weight decay: ' + str(weight_decay))
    opt_enc_nn = optim.Adam(enc_nn.parameters(),
                            lr=args.lr,
                            weight_decay=weight_decay)
    opt_metric_nn = optim.Adam(metric_nn.parameters(),
                               lr=args.lr,
                               weight_decay=weight_decay)

    model_summary([enc_nn, metric_nn])
    optimizer_summary([opt_enc_nn, opt_metric_nn])
    enc_nn.train()
    metric_nn.train()
    counter = 0
    total_loss = 0
    val_acc, val_acc_aux = 0, 0
    test_acc = 0
    for batch_idx in range(args.iterations):

        ####################
        # Train
        ####################
        data = train_loader.get_task_batch(
            batch_size=args.batch_size,
            n_way=args.train_N_way,
            unlabeled_extra=args.unlabeled_extra,
            num_shots=args.train_N_shots,
            cuda=args.cuda,
            variable=True)
        [
            batch_x, label_x, _, _, batches_xi, labels_yi, oracles_yi,
            hidden_labels
        ] = data
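        # Task batch layout (presumed from this repo's Generator): batch_x /
        # label_x are the query sample and its label; batches_xi / labels_yi
        # are the support samples and their labels; oracles_yi and
        # hidden_labels carry the semi-supervised (partially labeled) extras.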

        opt_enc_nn.zero_grad()
        opt_metric_nn.zero_grad()

        loss_d_metric = train_batch(model=[enc_nn, metric_nn, softmax_module],
                                    data=[
                                        batch_x, label_x, batches_xi,
                                        labels_yi, oracles_yi, hidden_labels
                                    ])

        opt_enc_nn.step()
        opt_metric_nn.step()

        adjust_learning_rate(optimizers=[opt_enc_nn, opt_metric_nn],
                             lr=args.lr,
                             iter=batch_idx)

        ####################
        # Display
        ####################
        counter += 1
        total_loss += loss_d_metric.item()  # .data[0] is invalid on 0-dim tensors in modern PyTorch
        if batch_idx % args.log_interval == 0:
            display_str = 'Train Iter: {}'.format(batch_idx)
            display_str += '\tLoss_d_metric: {:.6f}'.format(total_loss /
                                                            counter)
            logger.info(display_str)
            counter = 0
            total_loss = 0

        ####################
        # Test
        ####################
        if (batch_idx + 1) % args.test_interval == 0 or batch_idx == 20:
            if batch_idx == 20:
                test_samples = 100
            else:
                test_samples = 3000
            if args.dataset == 'mini_imagenet':
                val_acc_aux = test.test_one_shot(
                    args,
                    model=[enc_nn, metric_nn, softmax_module],
                    test_samples=test_samples * 5,
                    partition='val')
            test_acc_aux = test.test_one_shot(
                args,
                model=[enc_nn, metric_nn, softmax_module],
                test_samples=test_samples * 5,
                partition='test')
            test.test_one_shot(args,
                               model=[enc_nn, metric_nn, softmax_module],
                               test_samples=test_samples,
                               partition='train')
            enc_nn.train()
            metric_nn.train()

            if val_acc_aux is not None and val_acc_aux >= val_acc:
                test_acc = test_acc_aux
                val_acc = val_acc_aux

            if args.dataset == 'mini_imagenet':
                logger.info("Best test accuracy {:.4f} \n".format(test_acc))

        ####################
        # Save model
        ####################
        if (batch_idx + 1) % args.save_interval == 0:
            logger.info("saving model...")
            torch.save(enc_nn,
                       os.path.join(logger.get_logger_dir(), 'enc_nn.t7'))
            torch.save(metric_nn,
                       os.path.join(logger.get_logger_dir(), 'metric_nn.t7'))

    # Test after training
    test.test_one_shot(args,
                       model=[enc_nn, metric_nn, softmax_module],
                       test_samples=args.test_samples)
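
Both examples call an adjust_learning_rate helper that is not shown here. The sketch below illustrates what such a schedule could look like, assuming a simple step decay; the decay factor, interval, and everything else about the body are assumptions, not the repository's actual implementation:

def adjust_learning_rate(optimizers, lr, iter):
    # Assumed schedule: halve the base learning rate every 15,000
    # iterations (both numbers are illustrative, not from the source).
    new_lr = lr * (0.5 ** (iter // 15000))
    for optimizer in optimizers:
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr
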
Example #2
def train():
    """Main function used for training for model. Keeps iterating and updating parameters until early stop condition is reached."""

    # The Generator is used to sample batches.
    train_loader = generator.Generator(args.dataset_root,
                                       args,
                                       partition='train',
                                       dataset=args.dataset)

    io.cprint('Batch size: ' + str(args.batch_size))
    io.cprint('Learning rate: ' + str(args.lr))

    # Try to load models
    enc_nn = models.load_model('enc_nn', args, io)
    metric_nn = models.load_model('metric_nn', args, io)

    # Create models if none were loaded
    if enc_nn is None or metric_nn is None:
        enc_nn, metric_nn = models.create_models(args, train_loader)

    softmax_module = models.SoftmaxModule()
    if args.cuda:
        enc_nn.cuda()
        metric_nn.cuda()

    io.cprint(str(enc_nn))
    io.cprint(str(metric_nn))

    weight_decay = 0
    if args.dataset == 'sensor':
        weight_decay = 1e-6
        io.cprint('Weight decay: ' + str(weight_decay))

    opt_enc_nn = optim.Adam(filter(lambda p: p.requires_grad,
                                   enc_nn.parameters()),
                            lr=args.lr,
                            weight_decay=weight_decay)
    opt_metric_nn = optim.Adam(metric_nn.parameters(),
                               lr=args.lr,
                               weight_decay=weight_decay)

    enc_nn.train()
    metric_nn.train()
    counter = 0
    total_loss = 0
    test_cycle = 0
    batch_idx = 0

    start = time.time()
    print("Starting the timer")
    e_stop = early_stop.EarlyStopping()

    # Training loop: run until early stopping triggers
    while not e_stop.early_stop:
        ####################
        # Train
        ####################
        # Load a training batch
        data, _ = train_loader.get_task_batch(batch_size=args.batch_size,
                                              cuda=args.cuda,
                                              variable=True)

        [batch_x, label_x, _, _, batches_xi, labels_yi] = data
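        # batch_x / label_x: the query sample and its label; batches_xi /
        # labels_yi: the support samples and their labels (no semi-supervised
        # extras in this variant).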

        opt_enc_nn.zero_grad()
        opt_metric_nn.zero_grad()

        # Calculate the loss
        loss_d_metric = train_batch(
            model=[enc_nn, metric_nn, softmax_module],
            data=[batch_x, label_x, batches_xi, labels_yi])
        # Update parameters
        opt_enc_nn.step()
        opt_metric_nn.step()

        # Adjust the learning rate
        adjust_learning_rate(optimizers=[opt_enc_nn, opt_metric_nn],
                             lr=args.lr,
                             iter=batch_idx)

        ####################
        # Display
        ####################
        counter += 1
        total_loss += loss_d_metric.item()
        if batch_idx % args.log_interval == 0:
            display_str = 'Train Iter: {}'.format(batch_idx)
            display_str += '\tLoss_d_metric: {:.6f}'.format(total_loss /
                                                            counter)
            io.cprint(display_str)
            counter = 0
            total_loss = 0

        ####################
        # Test
        ####################
        # Test at specific intervals (and once at the first iteration)
        if (batch_idx + 1) % args.test_interval == 0 or batch_idx == 0:
            if batch_idx == 0:
                test_samples = 200
            else:
                test_samples = 300

            e_stop = test.test_one_shot(
                e_stop,
                test_cycle,
                args,
                model=[enc_nn, metric_nn, softmax_module],
                test_samples=test_samples,
                partition='val')

            enc_nn.train()
            metric_nn.train()

            test_cycle += 1

            end = time.time()
            io.cprint("Time elapsed: " + str(end - start))

        ####################
        # Save model
        ####################
        # Save models at specific intervals
        if (batch_idx + 1) % args.save_interval == 0:
            torch.save(enc_nn,
                       'checkpoints/%s/models/enc_nn.t7' % args.exp_name)
            torch.save(metric_nn,
                       'checkpoints/%s/models/metric_nn.t7' % args.exp_name)

        batch_idx += 1

    # Test after training
    # Load the best model
    final_enc_nn = models.load_best_model('enc_nn', io)
    final_metric_nn = models.load_best_model('metric_nn', io)

    if args.cuda:
        final_enc_nn.cuda()
        final_metric_nn.cuda()

    test.test_one_shot(e_stop,
                       test_cycle,
                       args,
                       model=[final_enc_nn, final_metric_nn, softmax_module],
                       test_samples=args.test_samples,
                       partition='test')
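
Example #2 relies on an early_stop.EarlyStopping tracker whose early_stop flag is presumably flipped inside test.test_one_shot once validation accuracy stops improving. The module itself is not shown; the sketch below assumes a simple patience-based tracker (the attribute names, method name, and patience value are all assumptions):

class EarlyStopping:
    def __init__(self, patience=10):
        self.patience = patience  # validation cycles to wait for improvement
        self.best_acc = 0.0       # best validation accuracy seen so far
        self.counter = 0          # cycles since the last improvement
        self.early_stop = False   # flag polled by the training loop

    def update(self, val_acc):
        # Reset the counter on improvement; otherwise count toward patience.
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

Under this assumption, test.test_one_shot would call e_stop.update(val_acc) after each validation pass and return the tracker, matching how the training loop consumes it.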