Code Example #1
def main():
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    train_dataset = get_dataset(args.dataset, 'train')
    test_dataset = get_dataset(args.dataset, 'test')
    pin_memory = (args.dataset == "imagenet")
    train_loader = DataLoader(train_dataset,
                              shuffle=True,
                              batch_size=args.batch,
                              num_workers=args.workers,
                              pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset,
                             shuffle=False,
                             batch_size=args.batch,
                             num_workers=args.workers,
                             pin_memory=pin_memory)

    model = get_architecture(args.arch, args.dataset)

    logfilename = os.path.join(args.outdir, 'log.txt')
    init_logfile(logfilename,
                 "epoch\ttime\tlr\ttrain loss\ttrain acc\ttestloss\ttest acc")

    criterion = CrossEntropyLoss().cuda()
    optimizer = SGD(model.parameters(),
                    lr=args.lr,
                    momentum=args.momentum,
                    weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer,
                       step_size=args.lr_step_size,
                       gamma=args.gamma)

    for epoch in range(args.epochs):
        scheduler.step(epoch)
        before = time.time()
        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, epoch, args.noise_sd)
        test_loss, test_acc = test(test_loader, model, criterion,
                                   args.noise_sd)
        after = time.time()

        log(
            logfilename, "{}\t{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
                epoch, str(datetime.timedelta(seconds=(after - before))),
                scheduler.get_lr()[0], train_loss, train_acc, test_loss,
                test_acc))

        torch.save(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(args.outdir, 'checkpoint.pth.tar'))
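
A note that applies to this and several later examples: calling scheduler.step(epoch) at the top of the epoch and reading the rate with get_lr() is the pre-1.1 PyTorch idiom. Since PyTorch 1.1, scheduler.step() should be called after optimizer.step() (once per epoch for StepLR), and since 1.4 get_last_lr() is the supported way to read the current rate. A minimal sketch of the modern pattern:

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import StepLR

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = SGD(params, lr=0.1)
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)  # lr *= 0.1 every 30 epochs

for epoch in range(90):
    # ... run the epoch's batches, each ending in optimizer.step() ...
    scheduler.step()                 # step once per epoch, after the optimizer
    lr = scheduler.get_last_lr()[0]  # PyTorch >= 1.4 replacement for get_lr()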
Code Example #2
def main(args):
    pics = os.listdir(PICS_PATH)
    data_set = CarPlateLoader(pics)
    data_loader = DataLoader(data_set, batch_size=50, shuffle=True, num_workers=8)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = Net().to(device)
    if os.path.exists("car_plate.pt"):
        model.load_state_dict(torch.load("car_plate.pt"))
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=1e-5)
    scheduler = StepLR(optimizer, step_size=2, gamma=0.9)

    for i in range(args.epoes):
        model.train()
        for i_batch, sample_batched in enumerate(data_loader):
            optimizer.zero_grad()
            img_tensor = sample_batched["img"].to(device)
            label_tensor = sample_batched["label"].to(device)
            output = model(img_tensor)
            loss = F.mse_loss(output, label_tensor)
            loss.backward()
            optimizer.step()
            if i_batch % 10 == 0:
                print(i, i_batch, "loss="+str(loss.cpu().item()), "lr="+str(scheduler.get_lr()))
        scheduler.step()

    torch.save(model.state_dict(), "car_plate.pt")
Code Example #3
def train(env_name,
          iterations,
          seed=42,
          model=None,
          render=True,
          lr=1e-3,
          batch_size=16,
          loss_base=2):
    # Create the environment
    env = make_env(env_name, seed)
    # Get PyTorch device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Declare writer
    writer = SummaryWriter('runs/reconstruction_logs/')
    # Create Qnetwork
    state = torch.load(
        model, map_location="cuda" if torch.cuda.is_available() else "cpu")
    net = QNetwork(env.observation_space,
                   env.action_space,
                   arch=state['arch'],
                   dueling=state.get('dueling', False)).to(device)
    net.load_state_dict(state['state_dict'])

    # Create the decoder and the optimizer
    dqn_decoder = DqnDecoder(net).to(device)
    optimizer = optim.Adam(dqn_decoder.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=1000, gamma=0.5)

    # Training loop
    data_generator = generate_batch(env, net, device, batch_size)
    for i in range(iterations):
        scheduler.step()
        writer.add_scalar('reconstruction/lr', scheduler.get_lr()[0], i)
        optimizer.zero_grad()
        # Get batch and AE reconstruction
        batch = next(data_generator)
        reconstruction = dqn_decoder(batch, net)
        # Compute loss and backpropagate
        loss = (
            (batch - reconstruction)**loss_base).abs().sum() / batch.size(0)
        loss.backward()
        optimizer.step()
        writer.add_scalar('reconstruction/loss', loss, i)

        # Visualize
        if i % 50 == 0:
            original = vutils.make_grid(batch[:4],
                                        normalize=True,
                                        scale_each=True)
            reco = vutils.make_grid(reconstruction[:4],
                                    normalize=True,
                                    scale_each=True)
            writer.add_image('original', original, i)
            writer.add_image('reconstruction', reco, i)

    # Save the decoder
    output_name = model.split('/')[-1].split('.')[0] + '_decoder' + '.pth'
    torch.save(dqn_decoder.state_dict(), output_name)
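
For the default loss_base=2, the loss above is the per-sample summed squared reconstruction error averaged over the batch; an equivalent, clearer spelling for that case:

loss = (batch - reconstruction).pow(2).sum() / batch.size(0)  # .abs() only matters for odd exponents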
Code Example #4
def _train(backbone_name: str, path_to_data_dir: str, path_to_checkpoints_dir: str):
    dataset = Dataset(path_to_data_dir, mode=Dataset.Mode.TRAIN)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=8, pin_memory=True)

    backbone = Interface.from_name(backbone_name)(pretrained=True)
    model = Model(backbone).cuda()
    optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=0.0005)
    scheduler = StepLR(optimizer, step_size=50000, gamma=0.1)

    step = 0
    time_checkpoint = time.time()
    losses = deque(maxlen=100)
    should_stop = False

    num_steps_to_display = 20
    num_steps_to_snapshot = 10000
    num_steps_to_stop_training = 70000

    print('Start training')

    while not should_stop:
        for batch_index, (_, image_batch, _, bboxes_batch, labels_batch) in enumerate(dataloader):
            assert image_batch.shape[0] == 1, 'only batch size of 1 is supported'

            image = image_batch[0].cuda()
            bboxes = bboxes_batch[0].cuda()
            labels = labels_batch[0].cuda()

            forward_input = Model.ForwardInput.Train(image, gt_classes=labels, gt_bboxes=bboxes)
            forward_output: Model.ForwardOutput.Train = model.train().forward(forward_input)

            loss = forward_output.anchor_objectness_loss + forward_output.anchor_transformer_loss + \
                forward_output.proposal_class_loss + forward_output.proposal_transformer_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            losses.append(loss.item())
            step += 1

            if step % num_steps_to_display == 0:
                elapsed_time = time.time() - time_checkpoint
                time_checkpoint = time.time()
                steps_per_sec = num_steps_to_display / elapsed_time
                avg_loss = sum(losses) / len(losses)
                lr = scheduler.get_lr()[0]
                print(f'[Step {step}] Avg. Loss = {avg_loss:.6f}, Learning Rate = {lr} ({steps_per_sec:.2f} steps/sec)')

            if step % num_steps_to_snapshot == 0:
                path_to_checkpoint = model.save(path_to_checkpoints_dir, step)
                print(f'Model saved to {path_to_checkpoint}')

            if step == num_steps_to_stop_training:
                should_stop = True
                break

    print('Done')
Code Example #5
File: LSTM.py Project: EEEGUI/Dengue
def train(city):
    print('start training ... ')
    is_val = False
    model = RNN(INPUT_SIZE, HIDDEN_SIZE, NUM_LAYERS).to(DEVICE)
    criterion = nn.L1Loss()
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_data = DengueData(city, 'train_all')
    train_loader = torch.utils.data.DataLoader(dataset=train_data,
                                               batch_size=BATCH_SIZE,
                                               shuffle=True)
    total_step = len(train_loader)

    logger = Logger('../output/logs_%s' % city)
    scheduler = StepLR(optimizer, LR_STEP_SIZE * total_step, LR_GAMMA)
    for epoch in range(NUM_EPOCHS):
        for i, (features, labels) in enumerate(train_loader):
            features = features.reshape(-1, SEQUENCE_LENGTH,
                                        INPUT_SIZE).float().to(DEVICE)
            labels = labels.to(DEVICE).float()
            outputs = model(features)
            print(outputs)
            loss = criterion(outputs.reshape(-1), labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            if (i + 1) % 5 == 0:
                print('Epoch[%d/%d], step[%d/%d], Loss:%.5f' %
                      (epoch + 1, NUM_EPOCHS, i + 1, total_step, loss.item()))
                # print('epoch:%d, step:%d, learning_rate:%.5f' % (epoch, i, scheduler.get_lr()[0]))
                logger.scalar_summary('loss', loss.item(),
                                      epoch * total_step + i + 1)
                logger.scalar_summary('learning_rate',
                                      scheduler.get_lr()[0],
                                      epoch * total_step + i + 1)

        if is_val:
            print('Epoch %d Start validating' % epoch)
            val_data = DengueData(city, 'val')
            val_loader = torch.utils.data.DataLoader(dataset=val_data,
                                                     batch_size=16,
                                                     shuffle=False)
            with torch.no_grad():
                mae = 0
                val_size = len(val_data)
                for features, labels in val_loader:
                    features = features.reshape(-1, SEQUENCE_LENGTH,
                                                INPUT_SIZE).float().to(DEVICE)
                    labels = labels.to(DEVICE).float()
                    outputs = model(features)
                    mae += criterion(outputs.reshape(-1),
                                     labels) * labels.size(0) / val_size
                print('The MAE of val dataset is %.5f' % mae)
                logger.scalar_summary('accuracy', mae, epoch)

    torch.save(model.state_dict(), '../config/%s_model.ckpt' % city)
Code Example #6
def test_ClipLR():
    target = [0.1**i for i in range(4)] + [1e-3]
    optimizer = torch.optim.SGD([torch.nn.Parameter()], lr=1.0)
    lr_scheduler = StepLR(optimizer, step_size=1, gamma=0.1)
    lr_scheduler = ClipLR(lr_scheduler, min=1e-3)
    output = []
    for epoch in range(5):
        lr_scheduler.step()
        output.extend(lr_scheduler.get_lr())
    np.testing.assert_allclose(output, target, atol=1e-6)
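
ClipLR is not a stock PyTorch scheduler. A minimal sketch of such a clamping wrapper, consistent with how the test drives it (the names and exact step semantics here are assumptions; the project's real implementation may differ):

class ClipLR:
    """Wrap another scheduler and clamp its learning rates from below (sketch)."""
    def __init__(self, scheduler, min=1e-3):
        self.scheduler = scheduler
        self.min = min

    def step(self):
        self.scheduler.step()
        # clamp whatever rates the inner scheduler just set
        for group in self.scheduler.optimizer.param_groups:
            group['lr'] = max(group['lr'], self.min)

    def get_lr(self):
        return [group['lr'] for group in self.scheduler.optimizer.param_groups]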
Code Example #7
def main(args):
    pics = os.listdir(PICS_PATH)
    data_set = CarPlateLoader(pics)
    data_loader = DataLoader(data_set,
                             batch_size=args.batch,
                             shuffle=True,
                             num_workers=8)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = Net(args.batch, device, 2).to(device)
    if os.path.exists("car_plate.pt"):
        model.load_state_dict(torch.load("car_plate.pt"))
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = StepLR(optimizer, step_size=10, gamma=0.9)
    for i in range(args.epoes):
        model.train()
        for i_batch, sample_batched in enumerate(data_loader):
            optimizer.zero_grad()
            img_tensor = sample_batched["img"].to(device)
            label_tensor = sample_batched["label"].to(device)
            if label_tensor.shape[0] != args.batch:
                continue
            output = model(img_tensor)
            inputs_size = torch.zeros(label_tensor.shape[0], dtype=torch.int)
            for ii in range(inputs_size.shape[0]):
                inputs_size[ii] = output.shape[0]
            targets_size = torch.zeros(label_tensor.shape[0], dtype=torch.int)
            for iii in range(targets_size.shape[0]):
                targets_size[iii] = label_tensor[iii].shape[0]
            loss = F.ctc_loss(output,
                              label_tensor,
                              inputs_size,
                              targets_size,
                              blank=65)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 400)
            optimizer.step()
            if i_batch % 100 == 0:
                tmp = torch.squeeze(output.cpu()[:, 0, :])
                values, indexs = tmp.max(1)
                print(
                    i, i_batch, "loss=" + str(loss.cpu().item()),
                    "lr=" + str(scheduler.get_lr()),
                    ",random check: label is " + parseOutput(label_tensor[0]) +
                    " ,network predict is " + parseOutput(indexs))
        scheduler.step()

    torch.save(model.state_dict(), "car_plate.pt")
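
For reference, F.ctc_loss expects output to hold log-probabilities of shape (T, N, C), which the Net here is assumed to emit. The two fill loops above can also be written without Python loops, assuming label_tensor is a fixed-width (N, S) batch as those loops imply:

batch_n = label_tensor.shape[0]
inputs_size = torch.full((batch_n,), output.shape[0], dtype=torch.int)
targets_size = torch.full((batch_n,), label_tensor.shape[1], dtype=torch.int)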
Code Example #8
def main():
    if args['gpu']:
        os.environ['CUDA_VISIBLE_DEVICES'] = args['gpu']

    if not os.path.exists(args['outdir']):
        os.mkdir(args['outdir'])

    train_loader, test_loader = loaddata(args)

    if torch.cuda.is_available():
        model = loadmodel(args)
        model = model.cuda()

    logfilename = os.path.join(args['outdir'], 'log.txt')
    init_logfile(logfilename,
                 "epoch\ttime\tlr\ttrain loss\ttrain acc\ttestloss\ttest acc")

    criterion = CrossEntropyLoss().cuda()
    optimizer = SGD(model.parameters(),
                    lr=args['lr'],
                    momentum=args['momentum'],
                    weight_decay=args['weight_decay'])
    scheduler = StepLR(optimizer,
                       step_size=args['lr_step_size'],
                       gamma=args['gamma'])

    for epoch in range(args['epochs']):
        scheduler.step(epoch)
        before = time.time()
        train_loss, train_acc = train(train_loader, model, criterion,
                                      optimizer, epoch, args['noise_sd'])
        test_loss, test_acc = test(test_loader, model, criterion,
                                   args['noise_sd'])
        after = time.time()

        log(
            logfilename, "{}\t{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
                epoch, str(datetime.timedelta(seconds=(after - before))),
                scheduler.get_lr()[0], train_loss, train_acc, test_loss,
                test_acc))

        torch.save(
            {
                'epoch': epoch + 1,
                'dataset': args['dataset'],
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(args['outdir'], 'checkpoint.pth.tar'))
Code Example #9
File: steplr.py Project: julian-risch/WI-IAT2020
class StepLREpochCallback(Callback):
    def __init__(self, step_size=30, gamma=0.1, on_iteration_every=None):
        """

        :param step_size: StepLR parameter
        :param gamma: StepLR parameter
        :param on_iteration_every: Whether to call it after every X batches if None it is called after every epoch
        """
        super(StepLREpochCallback, self).__init__()
        self.scheduler = None
        self.step_size = step_size
        self.gamma = gamma
        self.last_lr = -1
        self.on_iteration_every = on_iteration_every

    def on_train_begin(self, logs=None):
        self.scheduler = StepLR(self.trainer._optimizer, step_size=self.step_size, gamma=self.gamma)

    def on_epoch_end(self, epoch, logs=None):
        if self.on_iteration_every is None:
            self.scheduler.step()
            self.check_if_changed()

    def on_batch_end(self, iteration, logs=None):
        if self.on_iteration_every is not None and iteration % self.on_iteration_every == 0:
            self.scheduler.step()
            self.check_if_changed()

    def check_if_changed(self):
        current_lr = self.scheduler.get_lr()[0]

        if self.last_lr != -1:
            if self.last_lr != current_lr:
                self.last_lr = current_lr
                tqdm.write(f'StepLR changed Learning Rate to {self.last_lr}')
        else:
            self.last_lr = current_lr
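
The callback assumes a Keras-style trainer that exposes its optimizer as trainer._optimizer. A hypothetical registration, purely for illustration (the fit signature below is an assumption, not taken from the project):

# Hypothetical usage; the trainer API is assumed, not part of the snippet above.
epoch_cb = StepLREpochCallback(step_size=30, gamma=0.1)  # decay once every 30 epochs
batch_cb = StepLREpochCallback(step_size=1, gamma=0.99, on_iteration_every=100)  # or every 100 batches
trainer.fit(train_loader, epochs=90, callbacks=[epoch_cb])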
Code Example #10
def main(data_dir, output_dir, log_dir, epochs, batch, lr, model_kind):
    # use GPU
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # get data loaders
    training = get_dataloader(train=True, batch_size=batch, data_dir=data_dir)
    testing = get_dataloader(train=False, batch_size=batch, data_dir=data_dir)

    # model
    if model_kind == 'linear':
        model = Logistic().to(device)
    elif model_kind == 'nn':
        model = NeuralNework().to(device)
    else:
        model = CNN().to(device)

    info('Model')
    print(model)

    # cost function
    cost = torch.nn.BCELoss()

    # optimizers
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = StepLR(optimizer, 5)

    for epoch in range(1, epochs + 1):
        info('Epoch {}'.format(epoch))
        scheduler.step()
        print('Current learning rate: {}'.format(scheduler.get_lr()))
        train(model, device, training, cost, optimizer, epoch)
        test(model, device, testing, cost)

    # save model
    info('Saving Model')
    save_model(model, device, output_dir, 'model')
Code Example #11
File: train.py Project: jireh-father/waveglow
def train(num_gpus,
          rank,
          group_name,
          output_directory,
          epochs,
          learning_rate,
          sigma,
          iters_per_checkpoint,
          batch_size,
          seed,
          fp16_run,
          checkpoint_path,
          with_tensorboard,
          num_workers=4):
    print("num_workers", num_workers)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # =====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    # =====END:   ADDED FOR DISTRIBUTED======

    criterion = WaveGlowLoss(sigma)
    model = WaveGlow(**waveglow_config).cuda()

    # =====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
    # =====END:   ADDED FOR DISTRIBUTED======

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = StepLR(optimizer, step_size=1, gamma=0.96)

    if fp16_run:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path != "":
        model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
                                                      optimizer)
        iteration += 1  # next iteration is iteration + 1

    trainset = Mel2Samp(**data_config)
    evalset = Mel2Samp(**eval_data_config)
    # =====START: ADDED FOR DISTRIBUTED======
    train_sampler = DistributedSampler(trainset) if num_gpus > 1 else None
    eval_sampler = DistributedSampler(evalset) if num_gpus > 1 else None
    # =====END:   ADDED FOR DISTRIBUTED======
    train_loader = DataLoader(trainset,
                              num_workers=num_workers,
                              shuffle=False,
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=False,
                              drop_last=True)
    eval_loader = DataLoader(evalset,
                             num_workers=num_workers,
                             shuffle=False,
                             sampler=eval_sampler,
                             batch_size=batch_size,
                             pin_memory=False,
                             drop_last=True)

    # Get shared output_directory ready
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)

    if with_tensorboard and rank == 0:
        from tensorboardX import SummaryWriter
        logger = SummaryWriter(os.path.join(output_directory, 'logs'))

    epoch_offset = max(1, int(iteration / len(train_loader)))
    start_time = datetime.datetime.now()
    # ================ MAIN TRAINING LOOP! ===================
    for epoch in range(epoch_offset, epochs):
        print('Epoch:', epoch, 'LR:', scheduler.get_lr())
        elapsed = datetime.datetime.now() - start_time
        print("Epoch: [{}][els: {}] {}".format(
            datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S"), elapsed,
            epoch))
        model.train()
        total_loss = 0.
        for i, batch in enumerate(train_loader):
            model.zero_grad()

            if waveglow_config["multi_speaker_config"]["use_multi_speaker"]:
                mel, audio, spk_embed_or_id = batch
                spk_embed_or_id = torch.autograd.Variable(
                    spk_embed_or_id.cuda())
            else:
                mel, audio = batch
            mel = torch.autograd.Variable(mel.cuda())
            audio = torch.autograd.Variable(audio.cuda())

            if waveglow_config["multi_speaker_config"]["use_multi_speaker"]:
                outputs = model((mel, audio, spk_embed_or_id))
            else:
                outputs = model((mel, audio))

            loss = criterion(outputs)
            if num_gpus > 1:
                reduced_loss = reduce_tensor(loss.data, num_gpus).item()
            else:
                reduced_loss = loss.item()

            if fp16_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()
            total_loss += reduced_loss
            if i > 0 and i % 10 == 0:
                elapsed = datetime.datetime.now() - start_time
                print(
                    "[{}][els: {}] epoch {}, total steps {}, {}/{} steps:\t{:.9f}"
                    .format(
                        datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
                        elapsed, epoch, iteration, i, len(train_loader),
                        reduced_loss))
            if with_tensorboard and rank == 0:
                logger.add_scalar('training_loss', reduced_loss,
                                  i + len(train_loader) * epoch)

            if (iteration % iters_per_checkpoint == 0):
                if rank == 0:
                    checkpoint_path = "{}/waveglow_{}".format(
                        output_directory, iteration)
                    save_checkpoint(model, optimizer, learning_rate, iteration,
                                    checkpoint_path)

            iteration += 1
        elapsed = datetime.datetime.now() - start_time
        print("[{}][els: {}] {} epoch :\tavg loss {:.9f}".format(
            datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S"), elapsed,
            epoch, total_loss / len(train_loader)))
        scheduler.step()
        eval.eval(eval_loader, model, criterion, num_gpus, start_time, epoch,
                  waveglow_config["multi_speaker_config"]["use_multi_speaker"])
Code Example #12
File: train.py Project: priba/graph_metric.pytorch
def main():
    print('Loss & Optimizer')
    if args.loss == 'triplet':
        args.triplet = True
        criterion = TripletLoss(margin=args.margin, swap=args.swap)
    elif args.loss == 'triplet_distance':
        args.triplet = True
        criterion = TripletLoss(margin=args.margin, swap=args.swap, dist=True)
    else:
        args.triplet = False
        criterion = ContrastiveLoss(margin=args.margin)

    print('Prepare data')
    train_loader, valid_loader, valid_gallery_loader, test_loader, test_gallery_loader, in_size = load_data(
        args.dataset,
        args.data_path,
        triplet=args.triplet,
        batch_size=args.batch_size,
        prefetch=args.prefetch,
        set_partition=args.set_partition)

    print('Create model')
    if args.model == 'GAT':
        net = models.GNN_GAT(in_size,
                             args.hidden,
                             args.out_size,
                             dropout=args.dropout)
    elif args.model == 'GRU':
        net = models.GNN_GRU(in_size,
                             args.hidden,
                             args.out_size,
                             dropout=args.dropout)

    distNet = distance.SoftHd(args.out_size)

    optimizer = torch.optim.Adam(list(net.parameters()) +
                                 list(distNet.parameters()),
                                 args.learning_rate,
                                 weight_decay=args.decay)
    scheduler = StepLR(optimizer, 5, gamma=args.gamma)

    print('Check CUDA')
    if args.cuda and args.ngpu > 1:
        print('\t* Data Parallel **NOT TESTED**')
        net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

    if args.cuda:
        print('\t* CUDA')
        net, distNet = net.cuda(), distNet.cuda()
        criterion = criterion.cuda()

    start_epoch = 0
    best_perf = 0
    early_stop_counter = 0
    if args.load is not None:
        print('Loading model')
        checkpoint = load_checkpoint(args.load)
        net.load_state_dict(checkpoint['state_dict'])
        distNet.load_state_dict(checkpoint['state_dict_dist'])
        start_epoch = checkpoint['epoch']
        best_perf = checkpoint['best_perf']

    if not args.test:
        print('***Train***')

        for epoch in range(start_epoch, args.epochs):

            loss_train = train(train_loader, [net, distNet], optimizer,
                               args.cuda, criterion, epoch)
            acc_valid, map_valid = test(valid_loader,
                                        valid_gallery_loader, [net, distNet],
                                        args.cuda,
                                        validation=True)

            # Early-Stop + Save model
            if map_valid.avg > best_perf:
                best_perf = map_valid.avg
                early_stop_counter = 0
                if args.save is not None:
                    save_checkpoint(
                        {
                            'epoch': epoch + 1,
                            'state_dict': net.state_dict(),
                            'state_dict_dist': distNet.state_dict(),
                            'best_perf': best_perf
                        },
                        directory=args.save,
                        file_name='checkpoint')
            else:
                if early_stop_counter >= args.early_stop:
                    print('Early Stop epoch {}'.format(epoch))
                    break
                early_stop_counter += 1

            # Logger
            if args.log:
                # Scalars
                logger.add_scalar('loss_train', loss_train.avg)
                logger.add_scalar('acc_valid', acc_valid.avg)
                logger.add_scalar('map_valid', map_valid.avg)
                logger.add_scalar('learning_rate', scheduler.get_lr()[0])
                logger.step()

            scheduler.step()
        # Load Best model in case of save it
        if args.save is not None:
            print('Loading best model')
            best_model_file = os.path.join(args.save, 'checkpoint.pth')
            checkpoint = load_checkpoint(best_model_file)
            net.load_state_dict(checkpoint['state_dict'])
            distNet.load_state_dict(checkpoint['state_dict_dist'])
            print('Best model at epoch {epoch} and acc {acc}%'.format(
                epoch=checkpoint['epoch'], acc=checkpoint['best_perf']))

    print('***Valid***')
    test(valid_loader, valid_gallery_loader, [net, distNet], args.cuda)
    print('***Test***')
    test(test_loader, test_gallery_loader, [net, distNet], args.cuda)
    sys.exit()
Code Example #13
def _train(train_img_path, train_txt_path, val_img_path, val_txt_path,
           path_to_log_dir, path_to_restore_checkpoint_file, training_options):
    batch_size = training_options['batch_size']
    initial_learning_rate = training_options['learning_rate']
    initial_patience = training_options['patience']
    num_steps_to_show_loss = 100
    num_steps_to_check = 1000

    step = 0
    patience = initial_patience
    best_accuracy = 0.0
    duration = 0.0

    model = Model(21)
    model.cuda()

    transform = transforms.Compose([
        transforms.Resize([285, 285]),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
    train_loader = torch.utils.data.DataLoader(BarcodeDataset(
        train_img_path, train_txt_path, transform),
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True)
    evaluator = Evaluator(val_img_path, val_txt_path)
    optimizer = optim.SGD(model.parameters(),
                          lr=initial_learning_rate,
                          momentum=0.9,
                          weight_decay=0.0005)
    scheduler = StepLR(optimizer,
                       step_size=training_options['decay_steps'],
                       gamma=training_options['decay_rate'])

    if path_to_restore_checkpoint_file is not None:
        assert os.path.isfile(
            path_to_restore_checkpoint_file
        ), '%s not found' % path_to_restore_checkpoint_file
        step = model.restore(path_to_restore_checkpoint_file)
        scheduler.last_epoch = step
        print('Model restored from file: %s' % path_to_restore_checkpoint_file)

    path_to_losses_npy_file = os.path.join(path_to_log_dir, 'losses.npy')
    if os.path.isfile(path_to_losses_npy_file):
        losses = np.load(path_to_losses_npy_file)
    else:
        losses = np.empty([0], dtype=np.float32)

    while True:
        for batch_idx, (images, digits_labels) in enumerate(train_loader):
            start_time = time.time()
            images, digits_labels = images.cuda(), [
                digit_label.cuda() for digit_label in digits_labels
            ]
            digit2_logits, digit3_logits, digit4_logits, digit5_logits, digit6_logits, digit7_logits, digit8_logits, digit9_logits, digit10_logits, digit11_logits, digit12_logits, digit13_logits = model.train(
            )(images)
            loss = _loss(digit2_logits, digit3_logits, digit4_logits,
                         digit5_logits, digit6_logits, digit7_logits,
                         digit8_logits, digit9_logits, digit10_logits,
                         digit11_logits, digit12_logits, digit13_logits,
                         digits_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()
            step += 1
            duration += time.time() - start_time

            if step % num_steps_to_show_loss == 0:
                examples_per_sec = batch_size * num_steps_to_show_loss / duration
                duration = 0.0
                print(
                    '=> %s: step %d, loss = %f, learning_rate = %f (%.1f examples/sec)'
                    % (datetime.now(), step, loss.item(),
                       scheduler.get_lr()[0], examples_per_sec))

            if step % num_steps_to_check != 0:
                continue

            losses = np.append(losses, loss.item())
            np.save(path_to_losses_npy_file, losses)

            print('=> Evaluating on validation dataset...')
            accuracy = evaluator.evaluate(model)
            print('==> accuracy = %f, best accuracy %f' %
                  (accuracy, best_accuracy))

            if accuracy > best_accuracy:
                path_to_checkpoint_file = model.store(path_to_log_dir,
                                                      step=step)
                print('=> Model saved to file: %s' % path_to_checkpoint_file)
                patience = initial_patience
                best_accuracy = accuracy
            else:
                patience -= 1

            print('=> patience = %d' % patience)
            if patience == 0:
                return
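
(Note on the example above: the scheduler is stepped once per optimizer iteration rather than per epoch, so decay_steps counts batches; assigning scheduler.last_epoch = step on restore re-aligns the schedule with the checkpointed step count.)

Code Example #14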
#optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#epoch_ind_loaded = checkpoint['epoch']
#epoch_ind = checkpoint['epoch']
epoch_ind = -1
#torch.cuda.empty_cache()
#test_pp=evaluate(test_data, model, batch_size, device)
#print("Test perplexity after the loading: ", test_pp)
#logging.debug ("Test perplexity: %.1f",test_pp)
best_valid_pp = 10000
torch.cuda.empty_cache()

for ind in range(epochs):
    epoch_ind += 1
    print(count_parameters(model))
    logging.debug("Number parameters %s", str(count_parameters(model)))
    print(scheduler.get_lr())
    logging.debug("Scheduler.get_lr() %s", str(scheduler.get_lr()))
    random.shuffle(data['x_train'])
    train_data = [
        sentence for sentence in data['x_train'][num_valid_samples:]
        if len(sentence) > 2
    ]
    valid_data = [
        sentence for sentence in data['x_train'][:num_valid_samples]
        if len(sentence) > 2
    ]
    logging.debug("START EPOCH now = %s", datetime.now())
    print("START EPOCH now = ", datetime.now())
    logging.debug("Training epoch %d", epoch_ind)
    print("Training epoch %d", epoch_ind)
    train_epoch(train_data, model, optimizer, batch_size, device)
Code Example #15
File: solver.py Project: mikelmh025/ScanRefer
class Solver():
    def __init__(self,
                 model,
                 config,
                 dataloader,
                 optimizer,
                 stamp,
                 val_step=10,
                 detection=True,
                 reference=True,
                 use_lang_classifier=True,
                 lr_decay_step=None,
                 lr_decay_rate=None,
                 bn_decay_step=None,
                 bn_decay_rate=None):

        self.epoch = 0  # set in __call__
        self.verbose = 0  # set in __call__

        self.model = model
        self.config = config
        self.dataloader = dataloader
        self.optimizer = optimizer
        self.stamp = stamp
        self.val_step = val_step

        self.detection = detection
        self.reference = reference
        self.use_lang_classifier = use_lang_classifier

        self.lr_decay_step = lr_decay_step
        self.lr_decay_rate = lr_decay_rate
        self.bn_decay_step = bn_decay_step
        self.bn_decay_rate = bn_decay_rate

        self.best = {
            "epoch": 0,
            "loss": float("inf"),
            "ref_loss": float("inf"),
            "lang_loss": float("inf"),
            "objectness_loss": float("inf"),
            "vote_loss": float("inf"),
            "box_loss": float("inf"),
            "lang_acc": -float("inf"),
            "ref_acc": -float("inf"),
            "obj_acc": -float("inf"),
            "pos_ratio": -float("inf"),
            "neg_ratio": -float("inf"),
            "iou_rate_0.25": -float("inf"),
            "iou_rate_0.5": -float("inf")
        }

        # init log
        # contains all necessary info for all phases
        self.log = {"train": {}, "val": {}}

        # tensorboard
        os.makedirs(os.path.join(CONF.PATH.OUTPUT, stamp, "tensorboard/train"),
                    exist_ok=True)
        os.makedirs(os.path.join(CONF.PATH.OUTPUT, stamp, "tensorboard/val"),
                    exist_ok=True)
        self._log_writer = {
            "train":
            SummaryWriter(
                os.path.join(CONF.PATH.OUTPUT, stamp, "tensorboard/train")),
            "val":
            SummaryWriter(
                os.path.join(CONF.PATH.OUTPUT, stamp, "tensorboard/val"))
        }

        # training log
        log_path = os.path.join(CONF.PATH.OUTPUT, stamp, "log.txt")
        self.log_fout = open(log_path, "a")

        # private
        # only for internal access and temporary results
        self._running_log = {}
        self._global_iter_id = 0
        self._total_iter = {}  # set in __call__

        # templates
        self.__iter_report_template = ITER_REPORT_TEMPLATE
        self.__epoch_report_template = EPOCH_REPORT_TEMPLATE
        self.__best_report_template = BEST_REPORT_TEMPLATE

        # lr scheduler
        if lr_decay_step and lr_decay_rate:
            if isinstance(lr_decay_step, list):
                self.lr_scheduler = MultiStepLR(optimizer, lr_decay_step,
                                                lr_decay_rate)
            else:
                self.lr_scheduler = StepLR(optimizer, lr_decay_step,
                                           lr_decay_rate)
        else:
            self.lr_scheduler = None

        # bn scheduler
        if bn_decay_step and bn_decay_rate:
            it = -1
            start_epoch = 0
            BN_MOMENTUM_INIT = 0.5
            BN_MOMENTUM_MAX = 0.001
            bn_lbmd = lambda it: max(
                BN_MOMENTUM_INIT * bn_decay_rate ** (int(it / bn_decay_step)),
                BN_MOMENTUM_MAX)
            self.bn_scheduler = BNMomentumScheduler(model,
                                                    bn_lambda=bn_lbmd,
                                                    last_epoch=start_epoch - 1)
        else:
            self.bn_scheduler = None

    def __call__(self, epoch, verbose):
        # setting
        self.epoch = epoch
        self.verbose = verbose
        self._total_iter["train"] = len(self.dataloader["train"]) * epoch
        self._total_iter["val"] = len(self.dataloader["val"]) * self.val_step

        for epoch_id in range(epoch):
            try:
                self._log("epoch {} starting...".format(epoch_id + 1))

                # feed
                self._feed(self.dataloader["train"], "train", epoch_id)

                # save model
                self._log("saving last models...\n")
                model_root = os.path.join(CONF.PATH.OUTPUT, self.stamp)
                torch.save(self.model.state_dict(),
                           os.path.join(model_root, "model_last.pth"))

                # update lr scheduler
                if self.lr_scheduler:
                    print("update learning rate --> {}\n".format(
                        self.lr_scheduler.get_lr()))
                    self.lr_scheduler.step()

                # update bn scheduler
                if self.bn_scheduler:
                    print(
                        "update batch normalization momentum --> {}\n".format(
                            self.bn_scheduler.lmbd(
                                self.bn_scheduler.last_epoch)))
                    self.bn_scheduler.step()

            except KeyboardInterrupt:
                # finish training
                self._finish(epoch_id)
                exit()

        # finish training
        self._finish(epoch_id)

    def _log(self, info_str):
        self.log_fout.write(info_str + "\n")
        self.log_fout.flush()
        print(info_str)

    def _reset_log(self, phase):
        self.log[phase] = {
            # info
            "forward": [],
            "backward": [],
            "eval": [],
            "fetch": [],
            "iter_time": [],
            # loss (float, not torch.cuda.FloatTensor)
            "loss": [],
            "ref_loss": [],
            "lang_loss": [],
            "objectness_loss": [],
            "vote_loss": [],
            "box_loss": [],
            # scores (float, not torch.cuda.FloatTensor)
            "lang_acc": [],
            "ref_acc": [],
            "obj_acc": [],
            "pos_ratio": [],
            "neg_ratio": [],
            "iou_rate_0.25": [],
            "iou_rate_0.5": []
        }

    def _set_phase(self, phase):
        if phase == "train":
            self.model.train()
        elif phase == "val":
            self.model.eval()
        else:
            raise ValueError("invalid phase")

    def _forward(self, data_dict):
        data_dict = self.model(data_dict)

        return data_dict

    def _backward(self):
        # optimize
        self.optimizer.zero_grad()
        self._running_log["loss"].backward()
        self.optimizer.step()

    def _compute_loss(self, data_dict):
        _, data_dict = get_loss(data_dict=data_dict,
                                config=self.config,
                                detection=self.detection,
                                reference=self.reference,
                                use_lang_classifier=self.use_lang_classifier)

        # dump
        self._running_log["ref_loss"] = data_dict["ref_loss"]
        self._running_log["lang_loss"] = data_dict["lang_loss"]
        self._running_log["objectness_loss"] = data_dict["objectness_loss"]
        self._running_log["vote_loss"] = data_dict["vote_loss"]
        self._running_log["box_loss"] = data_dict["box_loss"]
        self._running_log["loss"] = data_dict["loss"]

    def _eval(self, data_dict):
        data_dict = get_eval(data_dict=data_dict,
                             config=self.config,
                             reference=self.reference,
                             use_lang_classifier=self.use_lang_classifier)

        # dump
        self._running_log["lang_acc"] = data_dict["lang_acc"].item()
        self._running_log["ref_acc"] = np.mean(data_dict["ref_acc"])
        self._running_log["obj_acc"] = data_dict["obj_acc"].item()
        self._running_log["pos_ratio"] = data_dict["pos_ratio"].item()
        self._running_log["neg_ratio"] = data_dict["neg_ratio"].item()
        self._running_log["iou_rate_0.25"] = np.mean(
            data_dict["ref_iou_rate_0.25"])
        self._running_log["iou_rate_0.5"] = np.mean(
            data_dict["ref_iou_rate_0.5"])

    def _feed(self, dataloader, phase, epoch_id):
        # switch mode
        self._set_phase(phase)

        # re-init log
        self._reset_log(phase)

        # change dataloader
        dataloader = dataloader if phase == "train" else tqdm(dataloader)

        # multiple = 0
        # total = 0
        # for data_dict in dataloader:
        #     # for key in data_dict:
        #     #     data_dict[key] = data_dict[key].cuda()
        #     data_dict['unique_multiple'] = data_dict['unique_multiple'].cuda()
        #     multiple += torch.sum(data_dict['unique_multiple']).item()
        #     total += data_dict['unique_multiple'].shape[0]

        # print("multiple",multiple, "total",total, "rate", multiple/total)
        # import sys
        # sys.exit()

        for data_dict in dataloader:
            # move to cuda
            for key in data_dict:
                data_dict[key] = data_dict[key].cuda()

            # initialize the running loss
            self._running_log = {
                # loss
                "loss": 0,
                "ref_loss": 0,
                "lang_loss": 0,
                "objectness_loss": 0,
                "vote_loss": 0,
                "box_loss": 0,
                # acc
                "lang_acc": 0,
                "ref_acc": 0,
                "obj_acc": 0,
                "pos_ratio": 0,
                "neg_ratio": 0,
                "iou_rate_0.25": 0,
                "iou_rate_0.5": 0
            }

            # load
            self.log[phase]["fetch"].append(
                data_dict["load_time"].sum().item())

            with torch.autograd.set_detect_anomaly(True):
                # forward
                start = time.time()
                data_dict = self._forward(data_dict)
                self._compute_loss(data_dict)
                self.log[phase]["forward"].append(time.time() - start)

                # backward
                if phase == "train":
                    start = time.time()
                    self._backward()
                    self.log[phase]["backward"].append(time.time() - start)

            # eval
            start = time.time()
            self._eval(data_dict)
            self.log[phase]["eval"].append(time.time() - start)

            # record log
            self.log[phase]["loss"].append(self._running_log["loss"].item())
            self.log[phase]["ref_loss"].append(
                self._running_log["ref_loss"].item())
            self.log[phase]["lang_loss"].append(
                self._running_log["lang_loss"].item())
            self.log[phase]["objectness_loss"].append(
                self._running_log["objectness_loss"].item())
            self.log[phase]["vote_loss"].append(
                self._running_log["vote_loss"].item())
            self.log[phase]["box_loss"].append(
                self._running_log["box_loss"].item())

            self.log[phase]["lang_acc"].append(self._running_log["lang_acc"])
            self.log[phase]["ref_acc"].append(self._running_log["ref_acc"])
            self.log[phase]["obj_acc"].append(self._running_log["obj_acc"])
            self.log[phase]["pos_ratio"].append(self._running_log["pos_ratio"])
            self.log[phase]["neg_ratio"].append(self._running_log["neg_ratio"])
            self.log[phase]["iou_rate_0.25"].append(
                self._running_log["iou_rate_0.25"])
            self.log[phase]["iou_rate_0.5"].append(
                self._running_log["iou_rate_0.5"])

            # report
            if phase == "train":
                iter_time = self.log[phase]["fetch"][-1]
                iter_time += self.log[phase]["forward"][-1]
                iter_time += self.log[phase]["backward"][-1]
                iter_time += self.log[phase]["eval"][-1]
                self.log[phase]["iter_time"].append(iter_time)
                if (self._global_iter_id + 1) % self.verbose == 0:
                    self._train_report(epoch_id)

                # evaluation
                if self._global_iter_id % self.val_step == 0:
                    print("evaluating...")
                    # val
                    self._feed(self.dataloader["val"], "val", epoch_id)
                    self._dump_log("val")
                    self._set_phase("train")
                    self._epoch_report(epoch_id)

                # dump log
                self._dump_log("train")
                self._global_iter_id += 1

        # check best
        if phase == "val":
            cur_criterion = "iou_rate_0.5"
            cur_best = np.mean(self.log[phase][cur_criterion])
            if cur_best > self.best[cur_criterion]:
                self._log("best {} achieved: {}".format(
                    cur_criterion, cur_best))
                self._log("current train_loss: {}".format(
                    np.mean(self.log["train"]["loss"])))
                self._log("current val_loss: {}".format(
                    np.mean(self.log["val"]["loss"])))
                self.best["epoch"] = epoch_id + 1
                self.best["loss"] = np.mean(self.log[phase]["loss"])
                self.best["ref_loss"] = np.mean(self.log[phase]["ref_loss"])
                self.best["lang_loss"] = np.mean(self.log[phase]["lang_loss"])
                self.best["objectness_loss"] = np.mean(
                    self.log[phase]["objectness_loss"])
                self.best["vote_loss"] = np.mean(self.log[phase]["vote_loss"])
                self.best["box_loss"] = np.mean(self.log[phase]["box_loss"])
                self.best["lang_acc"] = np.mean(self.log[phase]["lang_acc"])
                self.best["ref_acc"] = np.mean(self.log[phase]["ref_acc"])
                self.best["obj_acc"] = np.mean(self.log[phase]["obj_acc"])
                self.best["pos_ratio"] = np.mean(self.log[phase]["pos_ratio"])
                self.best["neg_ratio"] = np.mean(self.log[phase]["neg_ratio"])
                self.best["iou_rate_0.25"] = np.mean(
                    self.log[phase]["iou_rate_0.25"])
                self.best["iou_rate_0.5"] = np.mean(
                    self.log[phase]["iou_rate_0.5"])

                # save model
                self._log("saving best models...\n")
                model_root = os.path.join(CONF.PATH.OUTPUT, self.stamp)
                torch.save(self.model.state_dict(),
                           os.path.join(model_root, "model.pth"))

    def _dump_log(self, phase):
        log = {
            "loss": [
                "loss", "ref_loss", "lang_loss", "objectness_loss",
                "vote_loss", "box_loss"
            ],
            "score": [
                "lang_acc", "ref_acc", "obj_acc", "pos_ratio", "neg_ratio",
                "iou_rate_0.25", "iou_rate_0.5"
            ]
        }
        for key in log:
            for item in log[key]:
                self._log_writer[phase].add_scalar(
                    "{}/{}".format(key, item),
                    np.mean([v for v in self.log[phase][item]]),
                    self._global_iter_id)

    def _finish(self, epoch_id):
        # print best
        self._best_report()

        # save check point
        self._log("saving checkpoint...\n")
        save_dict = {
            "epoch": epoch_id,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict()
        }
        checkpoint_root = os.path.join(CONF.PATH.OUTPUT, self.stamp)
        torch.save(save_dict, os.path.join(checkpoint_root, "checkpoint.tar"))

        # save model
        self._log("saving last models...\n")
        model_root = os.path.join(CONF.PATH.OUTPUT, self.stamp)
        torch.save(self.model.state_dict(),
                   os.path.join(model_root, "model_last.pth"))

        # export
        for phase in ["train", "val"]:
            self._log_writer[phase].export_scalars_to_json(
                os.path.join(CONF.PATH.OUTPUT, self.stamp,
                             "tensorboard/{}".format(phase),
                             "all_scalars.json"))

    def _train_report(self, epoch_id):
        # compute ETA
        fetch_time = self.log["train"]["fetch"]
        forward_time = self.log["train"]["forward"]
        backward_time = self.log["train"]["backward"]
        eval_time = self.log["train"]["eval"]
        iter_time = self.log["train"]["iter_time"]

        mean_train_time = np.mean(iter_time)
        mean_est_val_time = np.mean([
            fetch + forward for fetch, forward in zip(fetch_time, forward_time)
        ])
        eta_sec = (self._total_iter["train"] - self._global_iter_id -
                   1) * mean_train_time
        eta_sec += len(self.dataloader["val"]) * np.ceil(
            self._total_iter["train"] / self.val_step) * mean_est_val_time
        eta = decode_eta(eta_sec)

        # print report
        iter_report = self.__iter_report_template.format(
            epoch_id=epoch_id + 1,
            iter_id=self._global_iter_id + 1,
            total_iter=self._total_iter["train"],
            train_loss=round(np.mean([v for v in self.log["train"]["loss"]]),
                             5),
            train_ref_loss=round(
                np.mean([v for v in self.log["train"]["ref_loss"]]), 5),
            train_lang_loss=round(
                np.mean([v for v in self.log["train"]["lang_loss"]]), 5),
            train_objectness_loss=round(
                np.mean([v for v in self.log["train"]["objectness_loss"]]), 5),
            train_vote_loss=round(
                np.mean([v for v in self.log["train"]["vote_loss"]]), 5),
            train_box_loss=round(
                np.mean([v for v in self.log["train"]["box_loss"]]), 5),
            train_lang_acc=round(
                np.mean([v for v in self.log["train"]["lang_acc"]]), 5),
            train_ref_acc=round(
                np.mean([v for v in self.log["train"]["ref_acc"]]), 5),
            train_obj_acc=round(
                np.mean([v for v in self.log["train"]["obj_acc"]]), 5),
            train_pos_ratio=round(
                np.mean([v for v in self.log["train"]["pos_ratio"]]), 5),
            train_neg_ratio=round(
                np.mean([v for v in self.log["train"]["neg_ratio"]]), 5),
            train_iou_rate_25=round(
                np.mean([v for v in self.log["train"]["iou_rate_0.25"]]), 5),
            train_iou_rate_5=round(
                np.mean([v for v in self.log["train"]["iou_rate_0.5"]]), 5),
            mean_fetch_time=round(np.mean(fetch_time), 5),
            mean_forward_time=round(np.mean(forward_time), 5),
            mean_backward_time=round(np.mean(backward_time), 5),
            mean_eval_time=round(np.mean(eval_time), 5),
            mean_iter_time=round(np.mean(iter_time), 5),
            eta_h=eta["h"],
            eta_m=eta["m"],
            eta_s=eta["s"])
        self._log(iter_report)

    def _epoch_report(self, epoch_id):
        self._log("epoch [{}/{}] done...".format(epoch_id + 1, self.epoch))
        epoch_report = self.__epoch_report_template.format(
            train_loss=round(np.mean([v for v in self.log["train"]["loss"]]),
                             5),
            train_ref_loss=round(
                np.mean([v for v in self.log["train"]["ref_loss"]]), 5),
            train_lang_loss=round(
                np.mean([v for v in self.log["train"]["lang_loss"]]), 5),
            train_objectness_loss=round(
                np.mean([v for v in self.log["train"]["objectness_loss"]]), 5),
            train_vote_loss=round(
                np.mean([v for v in self.log["train"]["vote_loss"]]), 5),
            train_box_loss=round(
                np.mean([v for v in self.log["train"]["box_loss"]]), 5),
            train_lang_acc=round(
                np.mean([v for v in self.log["train"]["lang_acc"]]), 5),
            train_ref_acc=round(
                np.mean([v for v in self.log["train"]["ref_acc"]]), 5),
            train_obj_acc=round(
                np.mean([v for v in self.log["train"]["obj_acc"]]), 5),
            train_pos_ratio=round(
                np.mean([v for v in self.log["train"]["pos_ratio"]]), 5),
            train_neg_ratio=round(
                np.mean([v for v in self.log["train"]["neg_ratio"]]), 5),
            train_iou_rate_25=round(
                np.mean([v for v in self.log["train"]["iou_rate_0.25"]]), 5),
            train_iou_rate_5=round(
                np.mean([v for v in self.log["train"]["iou_rate_0.5"]]), 5),
            val_loss=round(np.mean([v for v in self.log["val"]["loss"]]), 5),
            val_ref_loss=round(
                np.mean([v for v in self.log["val"]["ref_loss"]]), 5),
            val_lang_loss=round(
                np.mean([v for v in self.log["val"]["lang_loss"]]), 5),
            val_objectness_loss=round(
                np.mean([v for v in self.log["val"]["objectness_loss"]]), 5),
            val_vote_loss=round(
                np.mean([v for v in self.log["val"]["vote_loss"]]), 5),
            val_box_loss=round(
                np.mean([v for v in self.log["val"]["box_loss"]]), 5),
            val_lang_acc=round(
                np.mean([v for v in self.log["val"]["lang_acc"]]), 5),
            val_ref_acc=round(np.mean([v for v in self.log["val"]["ref_acc"]]),
                              5),
            val_obj_acc=round(np.mean([v for v in self.log["val"]["obj_acc"]]),
                              5),
            val_pos_ratio=round(
                np.mean([v for v in self.log["val"]["pos_ratio"]]), 5),
            val_neg_ratio=round(
                np.mean([v for v in self.log["val"]["neg_ratio"]]), 5),
            val_iou_rate_25=round(
                np.mean([v for v in self.log["val"]["iou_rate_0.25"]]), 5),
            val_iou_rate_5=round(
                np.mean([v for v in self.log["val"]["iou_rate_0.5"]]), 5),
        )
        self._log(epoch_report)

    def _best_report(self):
        self._log("training completed...")
        best_report = self.__best_report_template.format(
            epoch=self.best["epoch"],
            loss=round(self.best["loss"], 5),
            ref_loss=round(self.best["ref_loss"], 5),
            lang_loss=round(self.best["lang_loss"], 5),
            objectness_loss=round(self.best["objectness_loss"], 5),
            vote_loss=round(self.best["vote_loss"], 5),
            box_loss=round(self.best["box_loss"], 5),
            lang_acc=round(self.best["lang_acc"], 5),
            ref_acc=round(self.best["ref_acc"], 5),
            obj_acc=round(self.best["obj_acc"], 5),
            pos_ratio=round(self.best["pos_ratio"], 5),
            neg_ratio=round(self.best["neg_ratio"], 5),
            iou_rate_25=round(self.best["iou_rate_0.25"], 5),
            iou_rate_5=round(self.best["iou_rate_0.5"], 5),
        )
        self._log(best_report)
        with open(os.path.join(CONF.PATH.OUTPUT, self.stamp, "best.txt"),
                  "w") as f:
            f.write(best_report)
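Note: the report above averages whatever was appended to self.log during the epoch. Below is a minimal sketch of the logging pattern this implies; the dict layout is an assumption, not taken from the original trainer.

from collections import defaultdict

import numpy as np

# Assumed layout: self.log maps phase -> metric -> list of per-batch values,
# reset at the start of every epoch.
log = {"train": defaultdict(list), "val": defaultdict(list)}
log["train"]["loss"].append(0.42)  # appended once per batch
log["train"]["loss"].append(0.40)
print(round(float(np.mean(log["train"]["loss"])), 5))  # per-epoch mean, as in the report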
Code example #16
    def __getitem__(self, idx):
        x, y = self.pair[idx]
        # y = util.ImageRescale(y, [0, 1])
        x_tensor, y_tensor = self.ToTensor(x, y)
        return x_tensor, y_tensor

train_loader = Data.DataLoader(dataset=HQ_human_train(dataroot),
                               batch_size=batch_size, shuffle=True)

#%% 
print('training start...')
t1 = time.time()

for epoch in range(n_epoch):
    sum_loss = 0
    print('SegNet lr:{}, SynNet lr:{}'.format(Seg_sch.get_lr()[0], Syn_sch.get_lr()[0]))
    for step,(tensor_x,tensor_y) in enumerate(train_loader):
        model.train()
        
        x = Variable(tensor_x).to(device)
        y = Variable(tensor_y).to(device)
        y_seg,y_syn = model(x)
        
        l1, l2 = criterion(y, y_syn)
        loss = l1 + l2
        sum_loss += loss.item()  # use .item() so the autograd graph is not retained
        
        if epoch <= 3:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
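Note: a self-contained illustration of why the accumulation above uses .item(): summing raw loss tensors would keep every batch's autograd graph alive until the end of the epoch.

import torch

model = torch.nn.Linear(4, 1)
sum_loss = 0.0
for _ in range(3):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    sum_loss += loss.item()  # plain float: the graph for this batch can be freed
print(sum_loss / 3)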
Code example #17
File: lstm_qsim.py  Project: xujinglin/deep
        X_scores = torch.stack(score_list, 1)  #[batch_size, K=101]
        y_targets = Variable(
            torch.zeros(X_scores.size(0)).type(
                torch.LongTensor))  #[batch_size]
        if use_gpu:
            y_targets = y_targets.cuda()
        loss = criterion(X_scores, y_targets)  #y_target=0
        loss.backward()
        optimizer.step()

        running_train_loss += loss.item()  # loss.data[0] is deprecated

    #end for
    training_loss.append(running_train_loss)
    learning_rate_schedule.append(scheduler.get_lr())
    print "epoch: %4d, training loss: %.4f" % (epoch + 1, running_train_loss)

    torch.save(model, SAVE_PATH + SAVE_NAME)

    #early stopping
    patience = 4
    min_delta = 0.1
    if epoch == 0:
        patience_cnt = 0
    elif epoch > 0 and training_loss[epoch -
                                     1] - training_loss[epoch] > min_delta:
        patience_cnt = 0
    else:
        patience_cnt += 1
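Note: the patience bookkeeping above can be factored into a small helper. This is an illustrative sketch (the class and method names are invented, and it tracks the best loss so far rather than the previous epoch's loss, a common variant):

class EarlyStopping:
    """Stop when the loss has not improved by min_delta for `patience` epochs."""

    def __init__(self, patience=4, min_delta=0.1):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.count = 0

    def should_stop(self, loss):
        if self.best - loss > self.min_delta:  # improved enough: reset the counter
            self.best = loss
            self.count = 0
        else:
            self.count += 1
        return self.count >= self.patience


stopper = EarlyStopping()
for epoch, loss in enumerate([1.0, 0.8, 0.79, 0.78, 0.785, 0.781]):
    if stopper.should_stop(loss):
        print("stopping at epoch", epoch)
        break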
Code example #18
File: train.py  Project: tawituthai/ImageClassifier
def train_network(model_,
                  epochs_,
                  learn_rate_,
                  train_loaders_,
                  valid_loaders_,
                  device_,
                  print_every_=5):
    print("## Start training ...")
    # Define loss function
    criterion = nn.NLLLoss()

    # Define Optimizer (model_name is expected to be defined at module scope)
    if model_name == 'VGG19':
        optimizer = optim.Adam(model_.classifier.parameters(), lr=learn_rate_)
    else:
        optimizer = optim.Adam(model_.fc.parameters(), lr=learn_rate_)

    # Set Learning rate scheduler
    scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

    # Parameter for training
    steps = 0
    training_loss = 0

    # For plotting
    train_losses, valid_losses, acc_trace = [], [], []

    # Loop through epochs
    for epoch in range(epochs_):
        steps = 0
        # Decay Learning Rate
        scheduler.step()

        for images, labels in train_loaders_:
            # increase steps to keep track of the number of batches
            steps += 1
            # move images and labels to device
            images, labels = images.to(device_), labels.to(device_)
            # Reset Optimizer gradient
            optimizer.zero_grad()

            # Forward pass
            log_ps = model_.forward(images)
            # Calculate loss from criterion
            loss = criterion(log_ps, labels)
            # Calculate weight gradient from back-propagation process
            loss.backward()
            # Update weight, using optimizer
            optimizer.step()

            # Keep track of training loss
            training_loss += loss.item()

            # if steps is multiple of print_every then do validation test and print stats
            if steps % print_every_ == 0:
                # Turn-off Dropout
                model_.eval()

                valid_loss = 0
                accuracy = 0

                # Validation loop
                for images_valid, labels_valid in valid_loaders_:
                    # move images and labels to device
                    images_valid, labels_valid = images_valid.to(
                        device_), labels_valid.to(device_)

                    # Do forward pass
                    log_ps_valid = model_.forward(images_valid)
                    # Calculate Loss
                    loss_valid = criterion(log_ps_valid, labels_valid)

                    # Keep track of validation loss
                    valid_loss += loss_valid.item()

                    # Calculate validation accuracy
                    # Get logit from out network, need to do exponential
                    ps_valid = torch.exp(log_ps_valid)
                    # Get top prediction of each image in each batch
                    top_ps, top_class = ps_valid.topk(1, dim=1)
                    # Define Equality
                    equality = (top_class == labels_valid.view(
                        *top_class.shape))
                    accuracy += torch.mean(equality.type(
                        torch.FloatTensor)).item()

                # Collect data for ploting
                train_losses.append(training_loss / print_every_)
                valid_losses.append(valid_loss / len(valid_loaders_))
                acc_trace.append(accuracy / len(valid_loaders_))

                # After Validation loop is done, print out stats
                print("Learning Rate: {}".format(scheduler.get_lr()))
                print(
                    f"Epoch {epoch+1}/{epochs}.. "
                    f"Steps {steps}/{len(train_loaders_)}.. "
                    f"Training loss: {training_loss/print_every_:.3f}.. "
                    f"Validation loss: {valid_loss/len(valid_loaders_):.3f}.. "
                    f"Validation accuracy: {accuracy/len(valid_loaders_):.3f}")
                print(
                    "-------------------------------------------------------\r\n"
                )

                # Reset Training loss
                training_loss = 0
                # Set model back to Training mode
                model_.train()
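Note: since PyTorch 1.1 the recommended ordering is optimizer.step() before scheduler.step(), and from PyTorch 1.4 get_last_lr() is the supported way to read the current rate for logging (get_lr() can return a transient value when called by hand). A minimal self-contained sketch of that ordering, unrelated to this specific project:

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

for epoch in range(3):
    for _ in range(5):  # stand-in for iterating a DataLoader
        x = torch.randn(8, 10)
        y = torch.randint(0, 2, (8,))
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
    scheduler.step()  # once per epoch, after the optimizer updates
    print("lr:", scheduler.get_last_lr()[0])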
Code example #19
def main():
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    # Copy code to output directory
    copy_code(args.outdir)

    train_dataset = get_dataset(args.dataset, 'train')
    test_dataset = get_dataset(args.dataset, 'test')
    pin_memory = (args.dataset == "imagenet")
    train_loader = DataLoader(train_dataset,
                              shuffle=True,
                              batch_size=args.batch,
                              num_workers=args.workers,
                              pin_memory=pin_memory)
    test_loader = DataLoader(test_dataset,
                             shuffle=False,
                             batch_size=args.batch,
                             num_workers=args.workers,
                             pin_memory=pin_memory)
    ## This is used to test the performance of the denoiser attached to a cifar10 classifier
    cifar10_test_loader = DataLoader(get_dataset('cifar10', 'test'),
                                     shuffle=False,
                                     batch_size=args.batch,
                                     num_workers=args.workers,
                                     pin_memory=pin_memory)

    if args.pretrained_denoiser:
        checkpoint = torch.load(args.pretrained_denoiser)
        assert checkpoint['arch'] == args.arch
        denoiser = get_architecture(checkpoint['arch'], args.dataset)
        denoiser.load_state_dict(checkpoint['state_dict'])
    else:
        denoiser = get_architecture(args.arch, args.dataset)

    if args.optimizer == 'Adam':
        optimizer = Adam(denoiser.parameters(),
                         lr=args.lr,
                         weight_decay=args.weight_decay)
    elif args.optimizer == 'SGD':
        optimizer = SGD(denoiser.parameters(),
                        lr=args.lr,
                        momentum=args.momentum,
                        weight_decay=args.weight_decay)
    elif args.optimizer == 'AdamThenSGD':
        optimizer = Adam(denoiser.parameters(),
                         lr=args.lr,
                         weight_decay=args.weight_decay)
    scheduler = StepLR(optimizer,
                       step_size=args.lr_step_size,
                       gamma=args.gamma)

    starting_epoch = 0
    logfilename = os.path.join(args.outdir, 'log.txt')

    ## Resume from checkpoint if exists and if resume flag is True
    denoiser_path = os.path.join(args.outdir, 'checkpoint.pth.tar')
    if args.resume and os.path.isfile(denoiser_path):
        print("=> loading checkpoint '{}'".format(denoiser_path))
        checkpoint = torch.load(denoiser_path,
                                map_location=lambda storage, loc: storage)
        assert checkpoint['arch'] == args.arch
        starting_epoch = checkpoint['epoch']
        denoiser.load_state_dict(checkpoint['state_dict'])
        if starting_epoch >= args.start_sgd_epoch and args.optimizer == 'AdamThenSGD':  # do Adam for a few epochs, then continue with SGD
            print("-->[Switching from Adam to SGD.]")
            args.lr = args.start_sgd_lr
            optimizer = SGD(denoiser.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
            scheduler = StepLR(optimizer,
                               step_size=args.lr_step_size,
                               gamma=args.gamma)

        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})".format(
            denoiser_path, checkpoint['epoch']))
    else:
        if args.resume:
            print("=> no checkpoint found at '{}'".format(args.outdir))
        init_logfile(logfilename,
                     "epoch\ttime\tlr\ttrainloss\ttestloss\ttestAcc")

    if args.objective == 'denoising':
        criterion = MSELoss(size_average=None, reduce=None,
                            reduction='mean').cuda()
        best_loss = 1e6

    elif args.objective in ['classification', 'stability']:
        assert args.classifier != '', "Please specify a path to the classifier you want to attach the denoiser to."

        if args.classifier in IMAGENET_CLASSIFIERS:
            assert args.dataset == 'imagenet'
            # loading pretrained imagenet architectures
            clf = get_architecture(args.classifier,
                                   args.dataset,
                                   pytorch_pretrained=True)
        else:
            checkpoint = torch.load(args.classifier)
            clf = get_architecture(checkpoint['arch'], 'cifar10')
            clf.load_state_dict(checkpoint['state_dict'])
        clf.cuda().eval()
        requires_grad_(clf, False)
        criterion = CrossEntropyLoss(size_average=None,
                                     reduce=None,
                                     reduction='mean').cuda()
        best_acc = 0

    for epoch in range(starting_epoch, args.epochs):
        before = time.time()
        if args.objective == 'denoising':
            train_loss = train(train_loader, denoiser, criterion, optimizer,
                               epoch, args.noise_sd)
            test_loss = test(test_loader, denoiser, criterion, args.noise_sd,
                             args.print_freq, args.outdir)
            test_acc = 'NA'
        elif args.objective in ['classification', 'stability']:
            train_loss = train(train_loader, denoiser, criterion, optimizer,
                               epoch, args.noise_sd, clf)
            if args.dataset == 'imagenet':
                test_loss, test_acc = test_with_classifier(
                    test_loader, denoiser, criterion, args.noise_sd,
                    args.print_freq, clf)
            else:
                # This is needed so that cifar10 denoisers trained using imagenet32 are still evaluated on the cifar10 testset
                test_loss, test_acc = test_with_classifier(
                    cifar10_test_loader, denoiser, criterion, args.noise_sd,
                    args.print_freq, clf)

        after = time.time()

        log(
            logfilename, "{}\t{:.3}\t{:.3}\t{:.3}\t{:.3}\t{:.3}".format(
                epoch, after - before, args.lr, train_loss, test_loss,
                test_acc))

        scheduler.step(epoch)
        args.lr = scheduler.get_lr()[0]

        # Switch from Adam to SGD
        if epoch == args.start_sgd_epoch and args.optimizer == 'AdamThenSGD':  # do Adam for a few epochs, then continue with SGD
            print("-->[Switching from Adam to SGD.]")
            args.lr = args.start_sgd_lr
            optimizer = SGD(denoiser.parameters(),
                            lr=args.lr,
                            momentum=args.momentum,
                            weight_decay=args.weight_decay)
            scheduler = StepLR(optimizer,
                               step_size=args.lr_step_size,
                               gamma=args.gamma)

        torch.save(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': denoiser.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(args.outdir, 'checkpoint.pth.tar'))

        if args.objective == 'denoising' and test_loss < best_loss:
            best_loss = test_loss
        elif args.objective in ['classification', 'stability'
                                ] and test_acc > best_acc:
            best_acc = test_acc
        else:
            continue

        torch.save(
            {
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': denoiser.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, os.path.join(args.outdir, 'best.pth.tar'))
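Note: the switch logic above rebuilds both the optimizer and its scheduler at the switch epoch. A stripped-down sketch of that pattern; the model and hyperparameter values here are placeholders, not the project's:

import torch
from torch import nn, optim
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

start_sgd_epoch, start_sgd_lr = 5, 1e-2  # placeholder values
for epoch in range(10):
    # ... one epoch of training would run here ...
    optimizer.step()  # stand-in so the scheduler ordering warning is not triggered
    if epoch == start_sgd_epoch:
        # StepLR holds a reference to the optimizer it was built with,
        # so a new optimizer needs a new scheduler as well.
        optimizer = optim.SGD(model.parameters(), lr=start_sgd_lr, momentum=0.9)
        scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
    scheduler.step()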
Code example #20
class DeepQNetworkOptionAgent:
    def __init__(self,
                 hex_diffusion,
                 option_num,
                 isoption=False,
                 islocal=True,
                 ischarging=True):
        self.learning_rate = 1e-3  # 1e-4
        self.gamma = GAMMA
        self.start_epsilon = START_EPSILON
        self.final_epsilon = FINAL_EPSILON
        self.epsilon_steps = EPSILON_DECAY_STEPS
        self.memory = BatchReplayMemory(256)
        self.batch_size = BATCH_SIZE
        self.clipping_value = CLIPPING_VALUE
        self.input_dim = INPUT_DIM  # 3 input state
        self.relocation_dim = RELOCATION_DIM  # 7
        self.charging_dim = CHARGING_DIM  # 5
        self.option_dim = OPTION_DIM  # 3
        self.output_dim = DQN_OUTPUT_DIM  # 7+5+3 = 15
        self.num_option = option_num
        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.path = OPTION_DQN_SAVE_PATH
        self.state_feature_constructor = FeatureConstructor()

        # init higher level DQN network
        self.q_network = DQN_network(self.input_dim, self.output_dim)
        self.target_q_network = DQN_target_network(self.input_dim,
                                                   self.output_dim)
        self.optimizer = torch.optim.Adam(self.q_network.parameters(),
                                          lr=self.learning_rate)
        self.lr_scheduler = StepLR(optimizer=self.optimizer,
                                   step_size=1000,
                                   gamma=0.99)  # 1.79 e-6 at 0.5 million step.
        self.train_step = 0
        # self.load_network()
        self.q_network.to(self.device)
        self.target_q_network.to(self.device)

        self.decayed_epsilon = self.start_epsilon
        # init option network
        self.record_list = []
        self.global_state_dict = OrderedDict()
        self.time_interval = int(0)
        self.global_state_capacity = 5 * 1440  # we store 5 days' global states to fit replay buffer size.
        self.with_option = isoption
        self.with_charging = ischarging
        self.local_matching = islocal
        self.hex_diffusion = hex_diffusion

        self.h_network_list = []
        self.load_option_networks(self.num_option)
        self.middle_terminal = self.init_terminal_states()

    # def load_network(self, RESUME = False):
    #     if RESUME:
    #         lists = os.listdir(self.path)
    #         lists.sort(key=lambda fn: os.path.getmtime(self.path + "/" + fn))
    #         newest_file = os.path.join(self.path, lists[-1])
    #         path_checkpoint = newest_file
    #         checkpoint = torch.load(path_checkpoint)
    #
    #         self.q_network.load_state_dict(checkpoint['net'])
    #         self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
    #
    #         self.train_step = checkpoint['step']
    #         self.copy_parameter()
    #         # self.optimizer.load_state_dict(checkpoint['optimizer'])
    #         print('Successfully load saved network starting from {}!'.format(str(self.train_step)))

    def load_option_networks(self, option_num):
        for option_net_id in range(option_num):
            h_network = OptionNetwork(self.input_dim, 1 + 6 + 5)
            checkpoint = torch.load(
                H_AGENT_SAVE_PATH + 'ht_network_option_%d_1_0_1_11520.pkl' %
                (option_net_id)
            )  # lets try the saved networks after the 14th day.
            h_network.load_state_dict(checkpoint['net'])  # , False
            self.h_network_list.append(h_network.to(self.device))
            print(
                'Successfully load H network {}, total option network num is {}'
                .format(option_net_id, len(self.h_network_list)))

    def init_terminal_states(self):
        """
        we initial a dict to check the sets of terminal hex ids by hour by option id
        :param oid: ID for option network
        :return:
        """
        middle_terminal = defaultdict(list)
        for oid in range(self.num_option):
            with open(TERMINAL_STATE_SAVE_PATH + 'term_states_%d.csv' % oid,
                      'r') as ts:
                next(ts)
                for lines in ts:
                    line = lines.strip().split(',')
                    hr, hid = line  # hour, hex_id of a terminal state (the option id comes from the filename)
                    middle_terminal[(oid, int(hr))].append(hid)
        return middle_terminal

    def get_actions(self, states, num_valid_relos, assigned_option_ids,
                    global_state):
        """
        option_ids is at the first three slots in the action space, so action id <3 means the corresponding h_network id
        :param global_states:
        :param states: tuple of (tick, hex_id, SOC) and SOC is 0 - 100%
        :param num_valid_relos: only relocation to ADJACENT hexes / charging station is valid
        :states:
        :return: action ids ranges from (0,14) , converted action ids has converted the option ids to hte action ids that are selected by corresponding option networks
        """
        with torch.no_grad():
            self.decayed_epsilon = max(
                self.final_epsilon,
                (self.start_epsilon - self.train_step *
                 (self.start_epsilon - self.final_epsilon) /
                 self.epsilon_steps))

            state_reps = np.array([
                self.state_feature_constructor.construct_state_features(state)
                for state in states
            ])
            hex_diffusions = np.array([
                np.tile(self.hex_diffusion[state[1]], (1, 1, 1))
                for state in states
            ])  # state[1] is hex_id

            mask = self.get_action_mask(
                states,
                num_valid_relos)  # mask for unreachable primitive actions

            option_mask = self.get_option_mask(
                states
            )  # if the state is considered as terminal, we dont use it..
            # terminate_option_mask = torch.from_numpy(option_mask).to(dtype=torch.bool, device=self.device) # the DQN need a tensor as input, so convert it.

            if True:  # NOTE: likely a collapsed epsilon-greedy branch; the action values below are random, then masked
                full_action_values = np.random.random(
                    (len(states), self.output_dim
                     ))  # generate a matrix with values from 0 to 1
                for i, state in enumerate(states):
                    if assigned_option_ids[i] != -1:
                        full_action_values[i][assigned_option_ids[
                            i]] = 10  # a large enough number to maintain that option if it's terminal state, we next mask it with -1.
                    full_action_values[i][:self.option_dim] = np.negative(
                        option_mask[i, :self.option_dim]
                    )  # convert terminal agents to -1
                    full_action_values[i][(
                        self.option_dim + num_valid_relos[i]):(
                            self.option_dim + self.relocation_dim
                        )] = -1  # mask unreachable neighbors.
                    if state[-1] > HIGH_SOC_THRESHOLD:
                        full_action_values[i][(
                            self.option_dim + self.relocation_dim
                        ):] = -1  # no charging, must relocate
                    elif state[-1] < LOW_SOC_THRESHOLD:
                        full_action_values[i][:(
                            self.option_dim + self.relocation_dim
                        )] = -1  # no relocation, must charge
                action_indexes = np.argmax(full_action_values, 1).tolist()
                # # hard inplace the previously assigned options.
                # action_indexes[np.where(assigned_option_ids!=-1)] = assigned_option_ids[np.where(assigned_option_ids!=-1)]
            # after getting all action ids by DQN, we convert the ones triggered options to the primitive action ids.
            converted_action_indexes = self.convert_option_to_primitive_action_id(
                action_indexes, state_reps, global_state, hex_diffusions, mask)

        return np.array(action_indexes
                        ), np.array(converted_action_indexes) - self.option_dim

    def convert_option_to_primitive_action_id(self, action_indexes, state_reps,
                                              global_state, hex_diffusions,
                                              mask):
        """
        we convert the option ids, e.g., 0,1,2 for each H network, to the generated primitive action ids
        :param action_indexes:
        :param state_reps:
        :param global_state:
        :param hex_diffusions:
        :param mask:
        :return:
        """
        ids_require_option = defaultdict(list)
        for id, action_id in enumerate(action_indexes):
            if action_id < self.num_option:
                ids_require_option[action_id].append(id)
        for option_id in range(self.num_option):
            if ids_require_option[option_id]:
                full_option_values = self.h_network_list[option_id].forward(
                    torch.from_numpy(
                        state_reps[ids_require_option[option_id]]).to(
                            dtype=torch.float32, device=self.device),
                    torch.from_numpy(
                        np.concatenate([
                            np.tile(
                                global_state,
                                (len(ids_require_option[option_id]), 1, 1, 1)),
                            hex_diffusions[ids_require_option[option_id]]
                        ],
                                       axis=1)).to(dtype=torch.float32,
                                                   device=self.device))
                # here mask is of batch x 15 dimension, we omit the first 3 columns, which should be options.
                primitive_action_mask = mask[
                    ids_require_option[option_id], self.
                    option_dim:]  # only primitive actions in option generator
                full_option_values[primitive_action_mask] = -9e10
                option_generated_primitive_action_ids = torch.argmax(
                    full_option_values,
                    dim=1).cpu().numpy()  # let the option network select a primitive action
                action_indexes = np.asarray(action_indexes)
                action_indexes[ids_require_option[option_id]] = (
                    option_generated_primitive_action_ids + self.option_dim
                )  # shift by option_dim back into the primitive-action slots
                # cover the option id with the generated primitive action id
        return action_indexes

    def add_global_state_dict(self, global_state_list):
        for tick in global_state_list.keys():
            if tick not in self.global_state_dict.keys():
                self.global_state_dict[tick] = global_state_list[tick]
        if len(self.global_state_dict.keys(
        )) > self.global_state_capacity:  #capacity limit for global states
            for _ in range(
                    len(self.global_state_dict.keys()) -
                    self.global_state_capacity):
                self.global_state_dict.popitem(last=False)

    def add_transition(self, state, action, next_state, reward, terminate_flag,
                       time_steps, valid_action):
        self.memory.push(state, action, next_state, reward, terminate_flag,
                         time_steps, valid_action)

    def batch_sample(self):
        samples = self.memory.sample(
            self.batch_size)  # random.sample(self.memory, self.batch_size)
        return samples
        # state, action, next_state, reward = zip(*samples)
        # return state, action, next_state, reward

    def get_main_Q(self, local_state, global_state):
        return self.q_network.forward(local_state, global_state)

    def get_target_Q(self, local_state, global_state):
        return self.target_q_network.forward(local_state, global_state)

    def copy_parameter(self):
        self.target_q_network.load_state_dict(self.q_network.state_dict())

    def soft_target_update(self, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target
        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter
        """
        for target_param, local_param in zip(
                self.target_q_network.parameters(),
                self.q_network.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)

    def train(self, record_hist):
        self.train_step += 1
        if len(self.memory) < self.batch_size:
            print('batches in replay buffer is {}'.format(len(self.memory)))
            return

        transitions = self.batch_sample()
        batch = self.memory.Transition(*zip(*transitions))

        global_state_reps = [
            self.global_state_dict[int(state[0] / 60)] for state in batch.state
        ]  # should be list of np.array

        global_next_state_reps = [
            self.global_state_dict[int(state_[0] / 60)]
            for state_ in batch.next_state
        ]  # should be list of np.array

        state_reps = [
            self.state_feature_constructor.construct_state_features(state)
            for state in batch.state
        ]
        next_state_reps = [
            self.state_feature_constructor.construct_state_features(state_)
            for state_ in batch.next_state
        ]

        hex_diffusion = [
            np.tile(self.hex_diffusion[state[1]], (1, 1, 1))
            for state in batch.state
        ]
        hex_diffusion_ = [
            np.tile(self.hex_diffusion[state_[1]], (1, 1, 1))
            for state_ in batch.next_state
        ]

        state_batch = torch.from_numpy(np.array(state_reps)).to(
            dtype=torch.float32, device=self.device)
        action_batch = torch.from_numpy(np.array(
            batch.action)).unsqueeze(1).to(dtype=torch.int64,
                                           device=self.device)
        reward_batch = torch.from_numpy(np.array(
            batch.reward)).unsqueeze(1).to(dtype=torch.float32,
                                           device=self.device)
        time_step_batch = torch.from_numpy(np.array(
            batch.time_steps)).unsqueeze(1).to(dtype=torch.float32,
                                               device=self.device)

        next_state_batch = torch.from_numpy(np.array(next_state_reps)).to(
            device=self.device, dtype=torch.float32)
        global_state_batch = torch.from_numpy(
            np.concatenate(
                [np.array(global_state_reps),
                 np.array(hex_diffusion)], axis=1)).to(dtype=torch.float32,
                                                       device=self.device)
        global_next_state_batch = torch.from_numpy(
            np.concatenate(
                [np.array(global_next_state_reps),
                 np.array(hex_diffusion_)],
                axis=1)).to(dtype=torch.float32, device=self.device)

        q_state_action = self.get_main_Q(state_batch,
                                         global_state_batch).gather(
                                             1, action_batch.long())
        # add a mask
        all_q_ = self.get_target_Q(next_state_batch, global_next_state_batch)
        option_mask = self.get_option_mask(batch.next_state)
        mask_ = self.get_action_mask(
            batch.next_state,
            batch.valid_action_num)  # action mask for next state
        all_q_[option_mask] = -9e10
        all_q_[mask_] = -9e10
        maxq = all_q_.max(1)[0].detach().unsqueeze(1)
        y = reward_batch + maxq * torch.pow(self.gamma, time_step_batch)
        loss = F.smooth_l1_loss(q_state_action, y)
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.q_network.parameters(),
                                 self.clipping_value)
        self.optimizer.step()
        self.lr_scheduler.step()

        self.record_list.append([
            self.train_step,
            round(float(loss), 3),
            round(float(reward_batch.view(-1).mean()), 3)
        ])
        self.save_parameter(record_hist)
        print(
            'Training step is {}; Learning rate is {}; Epsilon is {}'.format(
                self.train_step, self.lr_scheduler.get_lr(),
                round(self.decayed_epsilon, 4)))

    def get_action_mask(self, batch_state, batch_valid_action):
        """
        the action space: the first 3 is for h_network slots, then 7 relocation actions,and 5 nearest charging stations.
        :param batch_state: state
        :param batch_valid_action: info that limites to relocate to reachable neighboring hexes
        :return:
        """
        mask = np.zeros((len(batch_state), self.output_dim))  # (num_state, 15)
        for i, state in enumerate(batch_state):
            mask[i][(self.option_dim + batch_valid_action[i]):(
                self.option_dim + self.relocation_dim
            )] = 1  # limited to relocate to reachable neighboring hexes
            if state[-1] > HIGH_SOC_THRESHOLD:
                mask[i][(
                    self.option_dim +
                    self.relocation_dim):] = 1  # no charging, must relocate
            elif state[-1] < LOW_SOC_THRESHOLD:
                mask[i][:(
                    self.option_dim +
                    self.relocation_dim)] = 1  # no relocation, must charge

        mask = torch.from_numpy(mask).to(dtype=torch.bool, device=self.device)
        return mask

    def get_option_mask(self, states):
        """
        self.is_terminal is to judge if the state is terminal state with the info of hour and hex_id
        :param states:
        :return:
        """
        terminate_option_mask = np.zeros((len(states), self.output_dim))
        for oid in range(self.num_option):
            terminate_option_mask[:, oid] = self.is_terminal(
                states, oid)  # set as 0 if not in terminal set
        for oid in range(self.num_option, self.option_dim):
            terminate_option_mask[:, oid] = 1  # mask out empty options
        return terminate_option_mask

    def is_terminal(self, states, oid):
        """

        :param states:
        :return: a list of bool
        """
        return [
            1 if state in self.middle_terminal[(oid,
                                                int(state[0] // (60 * 60) %
                                                    24))] else 0
            for state in states
        ]

    def is_initial(self, states, oid):
        """

        :param states:
        :return: a list of bool
        """
        return [
            1 if state not in self.middle_terminal[(oid,
                                                    int(state[0] // (60 * 60) %
                                                        24))] else 0
            for state in states
        ]

    def save_parameter(self, record_hist):
        # torch.save(self.q_network.state_dict(), self.dqn_path)
        if self.train_step % SAVING_CYCLE == 0:
            checkpoint = {
                "net": self.q_network.state_dict(),
                # 'optimizer': self.optimizer.state_dict(),
                "step": self.train_step,
                "lr_scheduler": self.lr_scheduler.state_dict()
            }
            if not os.path.isdir(self.path):
                os.mkdir(self.path)
            # print('the path is {}'.format('logs/dqn_model/duel_dqn_%s.pkl'%(str(self.train_step))))
            torch.save(
                checkpoint,
                'logs/test/cnn_dqn_model/dqn_with_option_%d_%d_%d_%d_%s.pkl' %
                (self.num_option, bool(self.with_option),
                 bool(self.with_charging), bool(
                     self.local_matching), str(self.train_step)))
            # record training process (stacked before)
            for item in self.record_list:
                record_hist.writelines('{},{},{}\n'.format(
                    item[0], item[1], item[2]))
            print(
                'Training step: {}, replay buffer size:{}, epsilon: {}, learning rate: {}'
                .format(self.record_list[-1][0], len(self.memory),
                        self.decayed_epsilon, self.lr_scheduler.get_lr()))
            self.record_list = []
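Note: BatchReplayMemory is not shown in this snippet, but the call self.memory.Transition(*zip(*transitions)) implies the classic namedtuple-backed buffer. The sketch below is a guess at a minimal implementation; field names are inferred from add_transition() and train() (the original passes valid_action into push() but reads batch.valid_action_num):

import random
from collections import deque, namedtuple


class BatchReplayMemory:
    # Field names inferred from add_transition()/train(); the last field is
    # read as `valid_action_num` even though push() receives `valid_action`.
    Transition = namedtuple(
        "Transition",
        ["state", "action", "next_state", "reward",
         "terminate_flag", "time_steps", "valid_action_num"])

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, *args):
        self.buffer.append(self.Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)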
Code example #21
def train():

    forward_times = counter_t = vis_count = 0

    dataset = BraTS_FLAIR(csv_dir,
                          hgg_dir,
                          transform=None,
                          train_size=train_size)
    dataset_val = BraTS_FLAIR_val(csv_dir, hgg_dir)  #val
    data_loader = DataLoader(dataset, batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(dataset_val,
                            batch_size_val,
                            shuffle=True,
                            num_workers=2)
    loaders = {'train': data_loader, 'val': val_loader}

    top_models = [(0, 10000)] * 5  # (epoch, loss) sentinels
    worst_val = 10000  # initial threshold for the best-checkpoint loop

    loss_fn = torch.nn.CrossEntropyLoss(weight=loss_weights)
    soft_max = torch.nn.Softmax(dim=1)

    opt = torch.optim.Adam(model.parameters(), lr=init_lr)
    scheduler = StepLR(opt, step_size=opt_step_size, gamma=opt_gamma)  #**
    # multisteplr 0.5,

    opt.zero_grad()

    for epoch in range(epochs):
        for e in loaders:
            if e == 'train':
                counter_t += 1
                model.train()
                grad = True  #
            else:
                model.eval()
                grad = False

            with torch.set_grad_enabled(grad):
                for idx, batch_data in enumerate(loaders[e]):
                    batch_input = Variable(batch_data['img'].float()).cuda()
                    batch_gt_mask = Variable(batch_data['mask'].float()).cuda()
                    batch_seg_orig = Variable(batch_data['seg_orig']).cuda()

                    pred_mask = model(batch_input)
                    if e == 'train': forward_times += 1

                    ce = loss_fn(pred_mask, batch_seg_orig.long())
                    soft_mask = soft_max(pred_mask)
                    dice = dice_loss(soft_mask, batch_gt_mask)

                    a, b, c, d = dice_loss_classes(soft_mask, batch_gt_mask)

                    loss = ce + (a + b + c + d) / 4  # +dice

                    print('Dice Losses: ', a.item(), b.item(), c.item(),
                          d.item())
                    if e == 'train':
                        Lc.append(ce.item())
                        cross_moving_avg = sum(Lc) / len(Lc)
                        Ldc.append(
                            np.array([a.item(),
                                      b.item(),
                                      c.item(),
                                      d.item()]))
                        Ld.append(dice.item())
                        dice_moving_avg = sum(Ld) / len(Ld)
                        L.append(loss.item())
                        loss_moving_avg = sum(L) / len(L)
                        loss.backward()
                        print('Epoch: ', epoch + 1, ' Batch: ', idx + 1,
                              ' lr: ',
                              scheduler.get_lr()[-1], ' Dice: ',
                              dice_moving_avg, ' CE: ', cross_moving_avg,
                              ' Loss :', loss_moving_avg)
                        if forward_times == grad_accu_times:
                            opt.step()
                            opt.zero_grad()
                            forward_times = 0
                            print('\nUpdate weights ... \n')

                        writer.add_scalar('Train Loss', loss.item(), counter_t)
                        writer.add_scalar('Train CE', ce.item(), counter_t)
                        writer.add_scalar('Train Dice', dice.item(), counter_t)
                        writer.add_scalar('D1', a.item(), counter_t)
                        writer.add_scalar('D2', b.item(), counter_t)
                        writer.add_scalar('D3', c.item(), counter_t)

                        writer.add_scalar('D4', d.item(), counter_t)
                        writer.add_scalar('Lr',
                                          scheduler.get_lr()[-1], counter_t)

                        # vis(soft_mask,batch_seg_orig,vis_count,mode='train')
                        # vis_count+=25 # += no of images in vis loop
                        scheduler.step()

                    else:
                        Lv.append(loss.item())
                        Lvd.append(dice.item())
                        Lvc.append(ce.item())
                        lv_avg = sum(Lv) / len(Lv)
                        lvd_avg = sum(Lvd) / len(Lvd)
                        lvc_avg = sum(Lvc) / len(Lvc)
                        Lvdc.append(
                            np.array([a.item(),
                                      b.item(),
                                      c.item(),
                                      d.item()]))
                        Ldc.append(
                            np.array([a.item(),
                                      b.item(),
                                      c.item(),
                                      d.item()]))

                        writer.add_scalar('Val Loss', loss.item(), counter_t)
                        writer.add_scalar('Val CE', ce.item(), counter_t)
                        writer.add_scalar('Val Dice', dice.item(), counter_t)
                        writer.add_scalar('D1v', a.item(), counter_t)
                        writer.add_scalar('D2v', b.item(), counter_t)
                        writer.add_scalar('D3v', c.item(), counter_t)
                        writer.add_scalar('D4', d.item(), counter_t)
                        # vis(soft_mask,batch_seg_orig,vis_count,mode='val')
                        print(save_initial)

                        print('Validation total Loss::::::::::',
                              round(loss.item(), 3))
                        del batch_input, batch_gt_mask, batch_seg_orig, pred_mask

                        print('current n worst val: ', round(loss.item(), 2),
                              worst_val)

                        if epoch > 55 and epoch % 15 == 0:  # save every 15 epochs for logging; overlaps with the top-k saving below

                            checkpoint = {
                                'epoch': epoch + 1,
                                'moving loss': L,
                                'dice': Ld,
                                'val': Lv,
                                'valc': Lvc,
                                'vald': Lvd,
                                'cross el': Lc,
                                'Ldvc': Lvdc,
                                'Ldc': Ldc,
                                'state_dict': model.state_dict(),
                                'optimizer': opt.state_dict()
                            }
                            torch.save(
                                checkpoint, save_initial + '-' +
                                str(round(loss.item(), 2)) + '|' + str(epoch + 1))
                            print(
                                'Saved at 25 : ', save_initial + '-' +
                                str(round(loss.item(), 2)) + '|' + str(epoch + 1))

                        if loss < worst_val:
                            # print('saving --------------------------------------',epoch)
                            top_models = sorted(
                                top_models,
                                key=lambda x: x[1])  # sort maybe not needed

                            checkpoint = {
                                'epoch': epoch + 1,
                                'moving loss': L,
                                'dice': Ld,
                                'val': Lv,
                                'valc': Lvc,
                                'vald': Lvd,
                                'cross el': Lc,
                                'Ldvc': Lvdc,
                                'Ldc': Ldc,
                                'state_dict': model.state_dict(),
                                'optimizer': opt.state_dict()
                            }
                            torch.save(
                                checkpoint, save_initial + '-' +
                                str(round(loss.item(), 2)) + '|' + str(epoch + 1))

                            to_be_deleted = save_initial + '-' + str(
                                round(top_models[-1][1], 2)) + '|' + str(
                                    top_models[-1][0])  # ...loss|epoch
                            # print(to_be_deleted)

                            top_models.append((epoch + 1, loss.item()))

                            top_models = sorted(top_models, key=lambda x: x[
                                1])  #sort after addition of new val
                            top_models.pop(-1)

                            print('top_models', top_models)

                            worst_val = top_models[-1][1]
                            if str(
                                    to_be_deleted
                            ) != save_initial + '-' + '10000.0|0':  # sentinel entries (epoch 0, loss 10000) have no file to delete
                                os.remove(to_be_deleted)
                                # print('successfully deleted------------------')

                        break
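Note: the manual top-5 bookkeeping above can also be expressed with a heap. A hedged sketch under the same goal (keep only the k lowest-loss checkpoints on disk); the class and method names are invented:

import heapq


class TopKCheckpoints:
    def __init__(self, k=5):
        self.k = k
        self.heap = []  # (-loss, path): the worst surviving checkpoint sits at heap[0]

    def add(self, loss, path):
        """Register a new checkpoint; return the path to delete, or None."""
        heapq.heappush(self.heap, (-loss, path))
        if len(self.heap) > self.k:
            _, evicted = heapq.heappop(self.heap)
            return evicted
        return None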
Code example #22
loss_history = []
for epoch in range(num_epochs):
    for data in dataloader:
        img = data
        img = img.view(img.size(0), -1)
        img = Variable(img).to(device)
        # ===================forward=====================
        output = model(img)
        loss = criterion(output, img)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================log========================
    print('epoch [{}/{}], loss:{:.4f}, lr:{:.9f}'.format(
        epoch + 1, num_epochs, loss.item(),
        scheduler.get_lr()[0]))
    # if epoch % 10 == 0:
    #     pic = output.cpu().data #to_img(output.cpu().data)
    #     save_image(pic, './simp_aut_img/image_{}.png'.format(epoch))
    scheduler.step()

    #Plot history
    loss_history.append(loss.item())
    plt.plot(loss_history)
    plt.xlabel('epochs')
    plt.ylabel('loss')

#torch.save(model.state_dict(), './sim_autoencoder.pth')
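Note: the loop above logs the loss of the final batch of each epoch only; a per-epoch mean is usually more stable. A small illustrative helper, not part of the original script:

class AverageMeter:
    """Running mean of a scalar, e.g. the per-batch loss within an epoch."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value, n=1):
        self.sum += value * n
        self.count += n

    @property
    def avg(self):
        return self.sum / max(self.count, 1)


meter = AverageMeter()
for batch_loss, batch_size in [(0.9, 32), (0.7, 32), (0.8, 16)]:
    meter.update(batch_loss, batch_size)
print(round(meter.avg, 4))  # weighted epoch mean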
Code example #23
File: main.py  Project: Flyfoxs/DFL-CNN
def main(_config):
    print('DFL-CNN <==> Part1 : prepare for parameters <==> Begin')
    global args, best_prec1
    args = edict(_config)

    print('DFL-CNN <==> Part1 : prepare for parameters <==> Done')

    print('DFL-CNN <==> Part2 : Load Network  <==> Begin')
    model = DFL_VGG16(k=10, nclass=200)
    if args.gpu is not None:
        model = nn.DataParallel(model, device_ids=range(args.gpu))
        model = model.cuda()
        cudnn.benchmark = True
    if args.init_type is not None:
        try:
            init_weights(model, init_type=args.init_type)
        except:
            sys.exit(
                'DFL-CNN <==> Part2 : Load Network  <==> Init_weights error!')

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(
                'DFL-CNN <==> Part2 : Load Network  <==> Continue from {} epoch {}'
                .format(args.resume, checkpoint['epoch']))
        else:
            print('DFL-CNN <==> Part2 : Load Network  <==> Failed')
    print('DFL-CNN <==> Part2 : Load Network  <==> Done')

    print('DFL-CNN <==> Part3 : Load Dataset  <==> Begin')
    # dataroot = os.path.abspath(args.dataroot)
    # traindir = os.path.join(dataroot, 'train')
    # testdir = os.path.join(dataroot, 'test')

    # ImageFolder to process img
    transform_train = get_transform_for_train()
    transform_test = get_transform_for_test()
    transform_test_simple = get_transform_for_test_simple()

    train_dataset = Cub2011(train=True, transform=transform_train)
    test_dataset = Cub2011(train=False, transform=transform_test)
    test_dataset_simple = Cub2011(train=False, transform=transform_test_simple)

    # A list for target to classname
    index2classlist = train_dataset.index2classlist()

    # data loader
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.gpu *
                                               args.train_batchsize_per_gpu,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True,
                                               drop_last=False)
    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=1,
                                              shuffle=True,
                                              num_workers=args.workers,
                                              pin_memory=True,
                                              drop_last=False)
    test_loader_simple = torch.utils.data.DataLoader(test_dataset_simple,
                                                     batch_size=1,
                                                     shuffle=True,
                                                     num_workers=args.workers,
                                                     pin_memory=True,
                                                     drop_last=False)
    print('DFL-CNN <==> Part3 : Load Dataset  <==> Done')

    print('DFL-CNN <==> Part4 : Train and Test  <==> Begin')
    # Create the scheduler once: re-creating StepLR inside the epoch loop
    # resets its step counter, so the learning rate would never decay.
    scheduler = StepLR(optimizer, step_size=5, gamma=0.8)
    tbar = tqdm(range(args.start_epoch, args.epochs))
    for epoch in tbar:
        top1, top5, losses = train(args, train_loader, model, criterion,
                                   optimizer, epoch)
        scheduler.step()

        # evaluate on validation set
        if epoch % args.eval_epoch == 0:
            prec1, prec5 = validate_simple(args, test_loader_simple, model,
                                           criterion, epoch)

            ex.log_scalar('trn_top1', top1, epoch)
            ex.log_scalar('trn_top5', top5, epoch)
            ex.log_scalar('val_top1', prec1, epoch)
            ex.log_scalar('val_top5', prec5, epoch)

            # remember best prec@1 and save checkpoint
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'best_prec1': best_prec1,
                    'optimizer': optimizer.state_dict(),
                    'prec1': prec1,
                }, is_best)

        # do a test for visualization
        if epoch % args.vis_epoch == 0 and epoch != 0:
            draw_patch(epoch, model, index2classlist, args)

        lr = scheduler.get_lr()
        tbar.set_postfix(train_top1=top1,
                         val_top1=prec1,
                         train_top5=top5,
                         val_prec5=prec5,
                         epoch=epoch,
                         lr=lr,
                         refresh=False)
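Note: a runnable demonstration of the bug fixed above. Constructing StepLR anew every epoch resets its internal counter, so the learning rate never reaches a decay boundary (recent PyTorch versions also warn about stepping a scheduler before its optimizer, which is harmless for this demo):

import torch
from torch.optim.lr_scheduler import StepLR

opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=1.0)
for epoch in range(20):
    scheduler = StepLR(opt, step_size=5, gamma=0.8)  # BUG: fresh counter each epoch
    scheduler.step()
print(opt.param_groups[0]["lr"])  # still 1.0; a single scheduler would give 0.8**4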
Code example #24
    loss1_history.append(running_loss1_mean / episode_timestep)
    loss2_history.append(running_loss2_mean / episode_timestep)
    running_loss1_mean = 0
    running_loss2_mean = 0

    avg_reward += episode_reward
    avg_timestep += episode_timestep

    avg_history['episodes'].append(episode_i + 1)
    avg_history['timesteps'].append(avg_timestep)
    avg_history['reward'].append(avg_reward)
    avg_timestep = 0
    avg_reward = 0.0

    if (episode_i + 1) % agg_interval == 0:
        print('Episode : ', episode_i + 1, 'Learning Rate', scheduler.get_lr(),
              'Loss : ', loss2_history[-1], 'Avg Timestep : ',
              avg_history['timesteps'][-1])

# In[]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 7))
plt.subplots_adjust(wspace=0.5)
axes[0][0].plot(avg_history['episodes'], avg_history['timesteps'])
axes[0][0].set_title('Timesteps per episode')
axes[0][0].set_ylabel('Timesteps')
axes[0][1].plot(avg_history['episodes'], avg_history['reward'])
axes[0][1].set_title('Reward per episode')
axes[0][1].set_ylabel('Reward')
axes[1][1].set_title('Actor Objective')
Code example #25
def main():
    print("init data folders")
    torch.backends.cudnn.benchmark = True
    torch.cuda.empty_cache()
    encoder_lv1 = models.Encoder()
    encoder_lv2 = models.Encoder()
    encoder_lv3 = models.Encoder()
    encoder_lv4 = models.Encoder()

    decoder_lv1 = models.Decoder()
    decoder_lv2 = models.Decoder()
    decoder_lv3 = models.Decoder()
    decoder_lv4 = models.Decoder()

    encoder_lv1.apply(weight_init).cuda(GPU)
    encoder_lv2.apply(weight_init).cuda(GPU)
    encoder_lv3.apply(weight_init).cuda(GPU)
    encoder_lv4.apply(weight_init).cuda(GPU)

    decoder_lv1.apply(weight_init).cuda(GPU)
    decoder_lv2.apply(weight_init).cuda(GPU)
    decoder_lv3.apply(weight_init).cuda(GPU)
    decoder_lv4.apply(weight_init).cuda(GPU)

    encoder_lv1_optim = RAdam(encoder_lv1.parameters(), lr=LEARNING_RATE)
    encoder_lv1_scheduler = StepLR(encoder_lv1_optim,
                                   step_size=1000,
                                   gamma=0.1)
    encoder_lv2_optim = RAdam(encoder_lv2.parameters(), lr=LEARNING_RATE)
    encoder_lv2_scheduler = StepLR(encoder_lv2_optim,
                                   step_size=1000,
                                   gamma=0.1)
    encoder_lv3_optim = RAdam(encoder_lv3.parameters(), lr=LEARNING_RATE)
    encoder_lv3_scheduler = StepLR(encoder_lv3_optim,
                                   step_size=1000,
                                   gamma=0.1)
    encoder_lv4_optim = RAdam(encoder_lv4.parameters(), lr=LEARNING_RATE)
    encoder_lv4_scheduler = StepLR(encoder_lv4_optim,
                                   step_size=1000,
                                   gamma=0.1)

    decoder_lv1_optim = RAdam(decoder_lv1.parameters(), lr=LEARNING_RATE)
    decoder_lv1_scheduler = StepLR(decoder_lv1_optim,
                                   step_size=1000,
                                   gamma=0.1)
    decoder_lv2_optim = RAdam(decoder_lv2.parameters(), lr=LEARNING_RATE)
    decoder_lv2_scheduler = StepLR(decoder_lv2_optim,
                                   step_size=1000,
                                   gamma=0.1)
    decoder_lv3_optim = RAdam(decoder_lv3.parameters(), lr=LEARNING_RATE)
    decoder_lv3_scheduler = StepLR(decoder_lv3_optim,
                                   step_size=1000,
                                   gamma=0.1)
    decoder_lv4_optim = RAdam(decoder_lv4.parameters(), lr=LEARNING_RATE)
    decoder_lv4_scheduler = StepLR(decoder_lv4_optim,
                                   step_size=1000,
                                   gamma=0.1)

    if os.path.exists(str('./checkpoints/' + METHOD + "/encoder_lv1.pkl")):
        encoder_lv1.load_state_dict(
            torch.load(str('./checkpoints/' + METHOD + "/encoder_lv1.pkl")))
        print("load encoder_lv1 success")
    if os.path.exists(str('./checkpoints/' + METHOD + "/encoder_lv2.pkl")):
        encoder_lv2.load_state_dict(
            torch.load(str('./checkpoints/' + METHOD + "/encoder_lv2.pkl")))
        print("load encoder_lv2 success")
    if os.path.exists(str('./checkpoints/' + METHOD + "/encoder_lv3.pkl")):
        encoder_lv3.load_state_dict(
            torch.load(str('./checkpoints/' + METHOD + "/encoder_lv3.pkl")))
        print("load encoder_lv3 success")
    if os.path.exists(str('./checkpoints/' + METHOD + "/encoder_lv4.pkl")):
        encoder_lv4.load_state_dict(
            torch.load(str('./checkpoints/' + METHOD + "/encoder_lv4.pkl")))
        print("load encoder_lv4 success")

    # for param in decoder_lv4.layer24.parameters():
    #     param.requires_grad = False
    # for param in encoder_lv3.parameters():
    #     param.requires_grad = False
    #     # print("检查部分参数是否固定......")
    #     print(encoder_lv3.layer1.bias.requires_grad)
    # for param in decoder_lv3.parameters():
    #     param.requires_grad = False
    # for param in encoder_lv2.parameters():
    #     param.requires_grad = False
    #     # print("检查部分参数是否固定......")
    #     print(encoder_lv2.layer1.bias.requires_grad)
    # for param in decoder_lv2.parameters():
    #     param.requires_grad = False

    if os.path.exists(str('./checkpoints/' + METHOD + "/decoder_lv1.pkl")):
        decoder_lv1.load_state_dict(
            torch.load(str('./checkpoints/' + METHOD + "/decoder_lv1.pkl")))
        print("load encoder_lv1 success")
    if os.path.exists(str('./checkpoints/' + METHOD + "/decoder_lv2.pkl")):
        decoder_lv2.load_state_dict(
            torch.load(str('./checkpoints/' + METHOD + "/decoder_lv2.pkl")))
        print("load decoder_lv2 success")
    if os.path.exists(str('./checkpoints/' + METHOD + "/decoder_lv3.pkl")):
        decoder_lv3.load_state_dict(
            torch.load(str('./checkpoints/' + METHOD + "/decoder_lv3.pkl")))
        print("load decoder_lv3 success")
    if os.path.exists(str('./checkpoints/' + METHOD + "/decoder_lv4.pkl")):
        decoder_lv4.load_state_dict(
            torch.load(str('./checkpoints/' + METHOD + "/decoder_lv4.pkl")))
        print("load decoder_lv4 success")

    if not os.path.exists('./checkpoints/' + METHOD):
        os.makedirs('./checkpoints/' + METHOD)

    for epoch in range(args.start_epoch, EPOCHS):
        epoch += 1

        print("Training...")
        print('lr:', encoder_lv1_scheduler.get_lr())

        train_dataset = GoProDataset(
            blur_image_files='./datas/GoPro/train_blur_file.txt',
            sharp_image_files='./datas/GoPro/train_sharp_file.txt',
            root_dir='./datas/GoPro',
            crop=True,
            crop_size=IMAGE_SIZE,
            transform=transforms.Compose([transforms.ToTensor()]))

        train_dataloader = DataLoader(train_dataset,
                                      batch_size=BATCH_SIZE,
                                      shuffle=True,
                                      num_workers=8,
                                      pin_memory=True)
        start = time.time()

        for iteration, images in enumerate(train_dataloader):
            mse = nn.MSELoss().cuda(GPU)

            # Variable() has been a no-op since PyTorch 0.4; plain tensors
            # would work directly here and below.
            gt = Variable(images['sharp_image'] - 0.5).cuda(GPU)
            H = gt.size(2)
            W = gt.size(3)

            images_lv1 = Variable(images['blur_image'] - 0.5).cuda(GPU)
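            # Multi-patch hierarchy: the input is split into halves (lv2),
            # quadrants (lv3), and eighths (lv4).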
            images_lv2_1 = images_lv1[:, :, 0:int(H / 2), :]
            images_lv2_2 = images_lv1[:, :, int(H / 2):H, :]
            images_lv3_1 = images_lv2_1[:, :, :, 0:int(W / 2)]
            images_lv3_2 = images_lv2_1[:, :, :, int(W / 2):W]
            images_lv3_3 = images_lv2_2[:, :, :, 0:int(W / 2)]
            images_lv3_4 = images_lv2_2[:, :, :, int(W / 2):W]
            images_lv4_1 = images_lv3_1[:, :, 0:int(H / 4), :]
            images_lv4_2 = images_lv3_1[:, :, int(H / 4):int(H / 2), :]
            images_lv4_3 = images_lv3_2[:, :, 0:int(H / 4), :]
            images_lv4_4 = images_lv3_2[:, :, int(H / 4):int(H / 2), :]
            images_lv4_5 = images_lv3_3[:, :, 0:int(H / 4), :]
            images_lv4_6 = images_lv3_3[:, :, int(H / 4):int(H / 2), :]
            images_lv4_7 = images_lv3_4[:, :, 0:int(H / 4), :]
            images_lv4_8 = images_lv3_4[:, :, int(H / 4):int(H / 2), :]

            feature_lv4_1 = encoder_lv4(images_lv4_1)
            feature_lv4_2 = encoder_lv4(images_lv4_2)
            feature_lv4_3 = encoder_lv4(images_lv4_3)
            feature_lv4_4 = encoder_lv4(images_lv4_4)
            feature_lv4_5 = encoder_lv4(images_lv4_5)
            feature_lv4_6 = encoder_lv4(images_lv4_6)
            feature_lv4_7 = encoder_lv4(images_lv4_7)
            feature_lv4_8 = encoder_lv4(images_lv4_8)
            feature_lv4_top_left = torch.cat((feature_lv4_1, feature_lv4_2), 2)
            feature_lv4_top_right = torch.cat((feature_lv4_3, feature_lv4_4),
                                              2)
            feature_lv4_bot_left = torch.cat((feature_lv4_5, feature_lv4_6), 2)
            feature_lv4_bot_right = torch.cat((feature_lv4_7, feature_lv4_8),
                                              2)
            feature_lv4_top = torch.cat(
                (feature_lv4_top_left, feature_lv4_top_right), 3)
            feature_lv4_bot = torch.cat(
                (feature_lv4_bot_left, feature_lv4_bot_right), 3)
            feature_lv4 = torch.cat((feature_lv4_top, feature_lv4_bot), 2)
            residual_lv4_top_left = decoder_lv4(feature_lv4_top_left)
            residual_lv4_top_right = decoder_lv4(feature_lv4_top_right)
            residual_lv4_bot_left = decoder_lv4(feature_lv4_bot_left)
            residual_lv4_bot_right = decoder_lv4(feature_lv4_bot_right)

            # Coarse-to-fine: the decoded lv4 residuals sharpen the lv3 input
            # patches before encoding; the same chaining repeats at lv2/lv1.
            feature_lv3_1 = encoder_lv3(images_lv3_1 + residual_lv4_top_left)
            feature_lv3_2 = encoder_lv3(images_lv3_2 + residual_lv4_top_right)
            feature_lv3_3 = encoder_lv3(images_lv3_3 + residual_lv4_bot_left)
            feature_lv3_4 = encoder_lv3(images_lv3_4 + residual_lv4_bot_right)
            feature_lv3_top = torch.cat(
                (feature_lv3_1, feature_lv3_2), 3) + feature_lv4_top
            feature_lv3_bot = torch.cat(
                (feature_lv3_3, feature_lv3_4), 3) + feature_lv4_bot
            feature_lv3 = torch.cat((feature_lv3_top, feature_lv3_bot), 2)
            residual_lv3_top = decoder_lv3(feature_lv3_top)
            residual_lv3_bot = decoder_lv3(feature_lv3_bot)

            feature_lv2_1 = encoder_lv2(images_lv2_1 + residual_lv3_top)
            feature_lv2_2 = encoder_lv2(images_lv2_2 + residual_lv3_bot)
            feature_lv2 = torch.cat(
                (feature_lv2_1, feature_lv2_2), 2) + feature_lv3
            residual_lv2 = decoder_lv2(feature_lv2)

            feature_lv1 = encoder_lv1(images_lv1 + residual_lv2) + feature_lv2
            deblur_image = decoder_lv1(feature_lv1)

            loss = mse(deblur_image, gt)

            encoder_lv1.zero_grad()
            encoder_lv2.zero_grad()
            encoder_lv3.zero_grad()
            encoder_lv4.zero_grad()

            decoder_lv1.zero_grad()
            decoder_lv2.zero_grad()
            decoder_lv3.zero_grad()
            decoder_lv4.zero_grad()

            loss.backward()

            encoder_lv1_optim.step()
            encoder_lv2_optim.step()
            encoder_lv3_optim.step()
            encoder_lv4_optim.step()

            decoder_lv1_optim.step()
            decoder_lv2_optim.step()
            decoder_lv3_optim.step()
            decoder_lv4_optim.step()

            if (iteration + 1) % 10 == 0:
                stop = time.time()
                print(METHOD + " epoch:", epoch, "iteration:", iteration + 1,
                      "loss:%.4f" % loss.item(), 'time:%.4f' % (stop - start))
                start = time.time()
        # StepLR.step() takes no epoch argument in recent PyTorch; calling it
        # once per epoch produces the same schedule.
        encoder_lv1_scheduler.step()
        encoder_lv2_scheduler.step()
        encoder_lv3_scheduler.step()
        encoder_lv4_scheduler.step()

        decoder_lv1_scheduler.step()
        decoder_lv2_scheduler.step()
        decoder_lv3_scheduler.step()
        decoder_lv4_scheduler.step()
        if epoch % 100 == 0:
            os.makedirs('./checkpoints/' + METHOD + '/epoch' + str(epoch),
                        exist_ok=True)

            print("Testing...")
            test_dataset = GoProDataset(
                blur_image_files='./datas/GoPro/test_blur_file.txt',
                sharp_image_files='./datas/GoPro/test_sharp_file.txt',
                root_dir='./datas/GoPro',
                transform=transforms.Compose([transforms.ToTensor()]))
            test_dataloader = DataLoader(test_dataset,
                                         batch_size=1,
                                         shuffle=False,
                                         num_workers=8,
                                         pin_memory=True)
            test_time = 0
            for iteration, images in enumerate(test_dataloader):
                with torch.no_grad():
                    start = time.time()
                    images_lv1 = Variable(images['blur_image'] - 0.5).cuda(GPU)
                    H = images_lv1.size(2)
                    W = images_lv1.size(3)
                    images_lv2_1 = images_lv1[:, :, 0:int(H / 2), :]
                    images_lv2_2 = images_lv1[:, :, int(H / 2):H, :]
                    images_lv3_1 = images_lv2_1[:, :, :, 0:int(W / 2)]
                    images_lv3_2 = images_lv2_1[:, :, :, int(W / 2):W]
                    images_lv3_3 = images_lv2_2[:, :, :, 0:int(W / 2)]
                    images_lv3_4 = images_lv2_2[:, :, :, int(W / 2):W]
                    images_lv4_1 = images_lv3_1[:, :, 0:int(H / 4), :]
                    images_lv4_2 = images_lv3_1[:, :, int(H / 4):int(H / 2), :]
                    images_lv4_3 = images_lv3_2[:, :, 0:int(H / 4), :]
                    images_lv4_4 = images_lv3_2[:, :, int(H / 4):int(H / 2), :]
                    images_lv4_5 = images_lv3_3[:, :, 0:int(H / 4), :]
                    images_lv4_6 = images_lv3_3[:, :, int(H / 4):int(H / 2), :]
                    images_lv4_7 = images_lv3_4[:, :, 0:int(H / 4), :]
                    images_lv4_8 = images_lv3_4[:, :, int(H / 4):int(H / 2), :]

                    feature_lv4_1 = encoder_lv4(images_lv4_1)
                    feature_lv4_2 = encoder_lv4(images_lv4_2)
                    feature_lv4_3 = encoder_lv4(images_lv4_3)
                    feature_lv4_4 = encoder_lv4(images_lv4_4)
                    feature_lv4_5 = encoder_lv4(images_lv4_5)
                    feature_lv4_6 = encoder_lv4(images_lv4_6)
                    feature_lv4_7 = encoder_lv4(images_lv4_7)
                    feature_lv4_8 = encoder_lv4(images_lv4_8)

                    feature_lv4_top_left = torch.cat(
                        (feature_lv4_1, feature_lv4_2), 2)
                    feature_lv4_top_right = torch.cat(
                        (feature_lv4_3, feature_lv4_4), 2)
                    feature_lv4_bot_left = torch.cat(
                        (feature_lv4_5, feature_lv4_6), 2)
                    feature_lv4_bot_right = torch.cat(
                        (feature_lv4_7, feature_lv4_8), 2)

                    feature_lv4_top = torch.cat(
                        (feature_lv4_top_left, feature_lv4_top_right), 3)
                    feature_lv4_bot = torch.cat(
                        (feature_lv4_bot_left, feature_lv4_bot_right), 3)

                    residual_lv4_top_left = decoder_lv4(feature_lv4_top_left)
                    residual_lv4_top_right = decoder_lv4(feature_lv4_top_right)
                    residual_lv4_bot_left = decoder_lv4(feature_lv4_bot_left)
                    residual_lv4_bot_right = decoder_lv4(feature_lv4_bot_right)

                    feature_lv3_1 = encoder_lv3(images_lv3_1 +
                                                residual_lv4_top_left)
                    feature_lv3_2 = encoder_lv3(images_lv3_2 +
                                                residual_lv4_top_right)
                    feature_lv3_3 = encoder_lv3(images_lv3_3 +
                                                residual_lv4_bot_left)
                    feature_lv3_4 = encoder_lv3(images_lv3_4 +
                                                residual_lv4_bot_right)

                    feature_lv3_top = torch.cat(
                        (feature_lv3_1, feature_lv3_2), 3) + feature_lv4_top
                    feature_lv3_bot = torch.cat(
                        (feature_lv3_3, feature_lv3_4), 3) + feature_lv4_bot
                    residual_lv3_top = decoder_lv3(feature_lv3_top)
                    residual_lv3_bot = decoder_lv3(feature_lv3_bot)

                    feature_lv2_1 = encoder_lv2(images_lv2_1 +
                                                residual_lv3_top)
                    feature_lv2_2 = encoder_lv2(images_lv2_2 +
                                                residual_lv3_bot)
                    feature_lv2 = torch.cat(
                        (feature_lv2_1, feature_lv2_2), 2) + torch.cat(
                            (feature_lv3_top, feature_lv3_bot), 2)
                    residual_lv2 = decoder_lv2(feature_lv2)

                    feature_lv1 = encoder_lv1(images_lv1 +
                                              residual_lv2) + feature_lv2
                    deblur_image = decoder_lv1(feature_lv1)
                    stop = time.time()
                    test_time += stop - start
                    print(
                        'RunTime:%.4f' % (stop - start),
                        '  Average Runtime:%.4f' % (test_time /
                                                    (iteration + 1)))
                    save_deblur_images(deblur_image.data + 0.5, iteration,
                                       epoch)

            # Save the epoch snapshot once, after the full test pass.
            torch.save(
                encoder_lv1.state_dict(),
                str('./checkpoints/' + METHOD + '/epoch' + str(epoch) +
                    "/encoder_lv1.pkl"))
            torch.save(
                encoder_lv2.state_dict(),
                str('./checkpoints/' + METHOD + '/epoch' + str(epoch) +
                    "/encoder_lv2.pkl"))
            torch.save(
                encoder_lv3.state_dict(),
                str('./checkpoints/' + METHOD + '/epoch' + str(epoch) +
                    "/encoder_lv3.pkl"))
            torch.save(
                encoder_lv4.state_dict(),
                str('./checkpoints/' + METHOD + '/epoch' + str(epoch) +
                    "/encoder_lv4.pkl"))
            torch.save(
                decoder_lv1.state_dict(),
                str('./checkpoints/' + METHOD + '/epoch' + str(epoch) +
                    "/decoder_lv1.pkl"))
            torch.save(
                decoder_lv2.state_dict(),
                str('./checkpoints/' + METHOD + '/epoch' + str(epoch) +
                    "/decoder_lv2.pkl"))
            torch.save(
                decoder_lv3.state_dict(),
                str('./checkpoints/' + METHOD + '/epoch' + str(epoch) +
                    "/decoder_lv3.pkl"))
            torch.save(
                decoder_lv4.state_dict(),
                str('./checkpoints/' + METHOD + '/epoch' + str(epoch) +
                    "/decoder_lv4.pkl"))

        torch.save(encoder_lv1.state_dict(),
                   str('./checkpoints/' + METHOD + "/encoder_lv1.pkl"))
        torch.save(encoder_lv2.state_dict(),
                   str('./checkpoints/' + METHOD + "/encoder_lv2.pkl"))
        torch.save(encoder_lv3.state_dict(),
                   str('./checkpoints/' + METHOD + "/encoder_lv3.pkl"))
        torch.save(encoder_lv4.state_dict(),
                   str('./checkpoints/' + METHOD + "/encoder_lv4.pkl"))
        torch.save(decoder_lv1.state_dict(),
                   str('./checkpoints/' + METHOD + "/decoder_lv1.pkl"))
        torch.save(decoder_lv2.state_dict(),
                   str('./checkpoints/' + METHOD + "/decoder_lv2.pkl"))
        torch.save(decoder_lv3.state_dict(),
                   str('./checkpoints/' + METHOD + "/decoder_lv3.pkl"))
        torch.save(decoder_lv4.state_dict(),
                   str('./checkpoints/' + METHOD + "/decoder_lv4.pkl"))
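
The eight load blocks and eight save blocks above are near-identical; keeping the sub-networks in a dict collapses them into two small helpers. A minimal refactoring sketch, assuming the encoder_lv*/decoder_lv* modules and METHOD are defined as in the example:

import os
import torch

def load_all(modules, ckpt_dir):
    # Load whichever checkpoints exist; skip missing ones.
    for name, module in modules.items():
        path = os.path.join(ckpt_dir, name + '.pkl')
        if os.path.exists(path):
            module.load_state_dict(torch.load(path))
            print('load', name, 'success')

def save_all(modules, ckpt_dir):
    # Write each module's state_dict as <name>.pkl under ckpt_dir.
    os.makedirs(ckpt_dir, exist_ok=True)
    for name, module in modules.items():
        torch.save(module.state_dict(), os.path.join(ckpt_dir, name + '.pkl'))

# usage sketch:
# modules = {'encoder_lv1': encoder_lv1, 'decoder_lv1': decoder_lv1, ...}
# load_all(modules, './checkpoints/' + METHOD)
# save_all(modules, './checkpoints/' + METHOD)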
Code example #26
            loss = vade.module.ELBO_Loss(x)

            opti.zero_grad()
            loss.backward()
            opti.step()

            L += loss.detach().cpu().numpy()

        pre = []
        tru = []

        with torch.no_grad():
            for x, y in DL:
                if args.cuda:
                    x = x.cuda()

                tru.append(y.numpy())
                pre.append(vade.module.predict(x))

        tru = np.concatenate(tru, 0)
        pre = np.concatenate(pre, 0)

        writer.add_scalar('loss', L / len(DL), epoch)
        writer.add_scalar('acc', cluster_acc(pre, tru)[0] * 100, epoch)
        writer.add_scalar('lr', lr_s.get_lr()[0], epoch)

        epoch_bar.write('Loss={:.4f},ACC={:.4f}%,LR={:.4f}'.format(
            L / len(DL),
            cluster_acc(pre, tru)[0] * 100,
            lr_s.get_lr()[0]))
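
Code example #26 reads the rate back with lr_s.get_lr()[0]. Since PyTorch 1.4, calling get_lr() outside the scheduler warns and may not match the rate actually in use; get_last_lr() is the documented accessor. A minimal sketch of the logging pattern (the one-parameter SGD optimizer is a stand-in, not the example's):

import torch
from torch.optim.lr_scheduler import StepLR

opt = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
lr_s = StepLR(opt, step_size=10, gamma=0.5)

for epoch in range(3):
    opt.step()                     # training step elided
    lr_s.step()
    lr = lr_s.get_last_lr()[0]     # instead of lr_s.get_lr()[0]
    print('epoch', epoch, 'lr', lr)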
Code example #27
    def train_model(self,
                    epochs=100,
                    batch_size=10,
                    lr=0.002,
                    gamma=0.95,
                    cutoff_epoch=10):
        """Trains the model for given epochs and batch_size and saves model at each epoch
            PARAMS:
                epochs (int): number of training rounds on the data
                batch_size (int): silmultaneous images to train on
                lr (float): learning rate for Adam optimizer
                gamma (float): geometric decrease rate for LRScheduler
                cutoff_epoch (int): epoch after which the lr starts to decrease
        """

        # Better numerical stability than BCELoss(sigmoid)
        criterion = nn.BCEWithLogitsLoss()
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      self.parameters()),
                               lr=lr,
                               betas=(0.9, 0.999))

        scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

        print("Training on {:d} images...".format(
            len(self.data_loader.train_images)))
        print("\r[{:0>2d}/{}]{}  loss: {:.6}".format(0, epochs,
                                                     progress_bar(0, epochs),
                                                     0.0),
              end=' ' * 10)

        losses = [[] for _ in range(4)]

        for e in range(epochs):
            print('LR: {:.7f}'.format(scheduler.get_lr()[0]), end=' ' * 10)

            for batch_i, (train_input, train_target) in enumerate(
                    self.data_loader.load_batch(batch_size)):

                # Gets data to the model's device
                train_input = torch.from_numpy(train_input).to(
                    device=self.device)
                train_target = torch.from_numpy(train_target).to(
                    device=self.device)

                output = self.forward(train_input)

                loss = criterion(output, train_target)
                l1_reg = output.norm(1)

                l2_reg = torch.tensor(0).float().to(self.device)
                for param in self.parameters():
                    l2_reg += param.norm(2)

                optimizer.zero_grad()
                (loss + self.l2_reg_weight * l2_reg +
                 self.l1_reg_weight * l1_reg).backward()
                optimizer.step()

                if not batch_i % 5:
                    print(
                        "\r[{:0>2d}/{}]{}  loss: {:.6f}  w/ reg {:.6f}".format(
                            e + 1, epochs,
                            progress_bar(
                                (e * self.data_loader.n_batches) + batch_i,
                                (epochs - 1) * self.data_loader.n_batches,
                                newline=False), loss.item(),
                            (loss + self.l2_reg_weight * l2_reg +
                             self.l1_reg_weight * l1_reg).item()),
                        end='\t')

                    losses[0].append(e +
                                     (batch_i) / self.data_loader.n_batches)
                    losses[1].append(loss.item())
                    losses[2].append(self.l1_reg_weight * l1_reg.item())
                    losses[3].append(self.l2_reg_weight * l2_reg.item())

            # Generate test images in eval() mode
            self.eval()
            _, axes = plt.subplots(1, 3, figsize=(60, 20))
            axes[0].imshow(train_input.cpu().numpy()[0].squeeze(), cmap='gray')
            axes[0].set_title("train_input")
            axes[1].imshow(
                (self(train_input)).cpu().detach().numpy()[0].squeeze(),
                cmap='gray')
            axes[1].set_title("output")
            axes[2].imshow(train_target.cpu().numpy()[0].squeeze(),
                           cmap='gray')
            axes[2].set_title("train_target")
            plt.savefig("outputs/fig_{:d}.png".format(e + 1))
            plt.close()
            self.train()

            plt.figure(figsize=(20, 20))
            plt.plot(losses[0], losses[1], label="BCE Loss")
            plt.plot(losses[0], losses[2], label="L1 reg")
            plt.plot(losses[0], losses[3], label="L2 reg")
            plt.legend()
            plt.xlabel("Epoch")
            plt.ylabel("Loss ")
            plt.savefig("outputs/loss.png", dpi=1000)
            plt.close()

            if e > cutoff_epoch:
                scheduler.step()

            if e > 0:
                self.save_state_dir('outputs/saved_models',
                                    "{}.pth".format(self.name.lower()))
        else:  # for-else: runs once after the epoch loop completes normally
            print('')
            with torch.no_grad():
                self.eval()
                test_input, test_target = self.data_loader.load_test()
                torch_input = torch.from_numpy(test_input).to(
                    device=self.device)

                test_pred = self.predict_stack(
                    torch_input).squeeze().detach().cpu().numpy()

            imageio.mimsave(
                'outputs/test_pred.gif',
                np.concatenate([(test_input.squeeze() * 255).astype('uint8'),
                                (test_pred * 255).astype('uint8'),
                                (test_target.squeeze() * 255).astype('uint8')],
                               axis=2))
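
Code example #27 delays the decay by guarding scheduler.step() with cutoff_epoch. The same schedule can be expressed as one LambdaLR so that step() runs unconditionally every epoch; a sketch assuming the rate stays flat through the cutoff and then decays by gamma per epoch (the boundary may differ from the guarded version by one epoch):

import torch
from torch.optim.lr_scheduler import LambdaLR

def delayed_decay(cutoff_epoch, gamma):
    def factor(epoch):
        return 1.0 if epoch <= cutoff_epoch else gamma ** (epoch - cutoff_epoch)
    return factor

params = [torch.zeros(1, requires_grad=True)]
optimizer = torch.optim.Adam(params, lr=0.002)
scheduler = LambdaLR(optimizer, lr_lambda=delayed_decay(10, 0.95))

for e in range(30):
    optimizer.step()    # training elided
    scheduler.step()    # no cutoff guard needed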
Code example #28
overall_step = 0
m = find_max(root_dir + 'RGB', 4, 1, 1)
m = np.cumsum(m, axis=0)

# preprocess(root_dir,root_dir+'RGB',root_dir+'Depth',root_dir+'Albedos',root_dir+'Normals',root_dir+'GroundTruth',m,512)
dataset = AutoEncoderData(root_dir + 'RGB', root_dir + 'input',
                          root_dir + 'gt', (512, 512), m, True, 256)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
total_step = len(train_loader)

torch.cuda.empty_cache()
for epoch in range(100):
    start = timeit.default_timer()
    total_loss = 0
    total_loss_num = 0
    print('Epoch:', epoch, 'LR:', scheduler.get_lr())
    count = 0
    for i, data in enumerate(train_loader):

        input = data['image'].float().to(device)
        label = data['output'].float().to(device)
        outputs = torch.zeros_like(label)
        targets = torch.zeros_like(label)
        loss_final = 0
        ls_final = 0
        lg_final = 0
        lt_final = 0
        for j in range(7):
            input_i = input[:, j, :, :, :]
            label_i = label[:, j, :, :, :]
            output = model(input_i, j)
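
Code example #28 is cut off right after the per-frame forward pass; the surrounding variables (loss_final, outputs, targets) suggest an accumulate-then-step pattern. The sketch below is a generic, hypothetical illustration of that pattern, not the original file's continuation; the model, loss, and shapes are stand-ins:

import torch
import torch.nn as nn

model = nn.Conv2d(3, 3, 3, padding=1)       # stand-in for the real model
criterion = nn.L1Loss()                     # assumed loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

input = torch.rand(1, 7, 3, 64, 64)         # (batch, frame, C, H, W)
label = torch.rand(1, 7, 3, 64, 64)

loss_final = 0.0
for j in range(7):                          # one forward pass per frame
    output = model(input[:, j])
    loss_final = loss_final + criterion(output, label[:, j])

optimizer.zero_grad()
loss_final.backward()                       # single backward for all frames
optimizer.step()
scheduler.step()                            # once per epoch in the real loop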
Code example #29
    def train_model(self, epochs=100, batch_size=10, lr=0.0002, gamma=1.0):
        """Trains the model for given epochs and batch_size and saves model at each epoch
            PARAMS:
                epochs (int): number of training rounds on the data
                batch_size (int): silmultaneous images to train on
                lr (float): learning rate for Adam optimizer
                gamma (float): geometric decrease rate for LRScheduler
        """

        criterion = nn.MSELoss()
        optimizer = optim.Adam(filter(lambda p: p.requires_grad,
                                      self.parameters()),
                               lr=lr,
                               betas=(0.9, 0.999))

        scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

        losses = [[] for _ in range(5)]

        print("Training on {:d} images...".format(
            len(self.data_loader.train_images)))
        print("\r[{}/{}]{}  loss: {:.6}".format(0, epochs,
                                                progress_bar(0, epochs), 0.0),
              end=' ' * 10)

        for e in range(epochs):
            print('LR: {:.7f}'.format(scheduler.get_lr()[0]), end=' ' * 10)

            for batch_i, (train_input, train_target) in enumerate(
                    self.data_loader.load_batch(batch_size)):

                # Gets data to the model's device
                train_input = torch.from_numpy(train_input).to(
                    device=self.device)
                train_target = torch.from_numpy(train_target).to(
                    device=self.device)

                output = self.forward(train_input)
                loss = criterion(output, train_target)

                # MSELoss computed on pili pixels only
                masked_loss = ((output - train_target)**2 *
                               (train_target != 0).float()).sum()

                l1_reg = output.norm(1)

                l2_reg = torch.tensor(0).float().to(self.device)
                for param in self.parameters():
                    l2_reg += param.norm(2)

                optimizer.zero_grad()
                (loss + masked_loss * self.masked_loss_weight +
                 l2_reg * self.l2_reg_weight +
                 l1_reg * self.l1_reg_weight).backward()
                optimizer.step()

                if not batch_i % 5:
                    print(
                        "\r[{:0>2d}/{}]{}  loss: {:.6f}  w/ reg {:.6f}".format(
                            e + 1, epochs,
                            progress_bar(
                                (e * self.data_loader.n_batches) + batch_i,
                                (epochs - 1) * self.data_loader.n_batches,
                                newline=False),
                            (loss +
                             masked_loss * self.masked_loss_weight).item(),
                            (loss + masked_loss * self.masked_loss_weight +
                             l2_reg * self.l2_reg_weight +
                             l1_reg * self.l1_reg_weight).item()),
                        end='\t')

                    losses[0].append(e +
                                     (batch_i) / self.data_loader.n_batches)
                    losses[1].append((loss).item())
                    losses[2].append(
                        (masked_loss * self.masked_loss_weight).item())
                    losses[3].append((l1_reg * self.l1_reg_weight).item())
                    losses[4].append((l2_reg * self.l2_reg_weight).item())

            if e > 30:
                scheduler.step()

            if not (e + 1) % 5:
                # Generate validation images in eval() mode
                self.eval()
                with torch.no_grad():
                    val_input, val_target = self.data_loader.load_val()

                    _, axes = plt.subplots(1, 3, figsize=(60, 20))
                    axes[0].imshow(val_input.squeeze(), cmap='gray')
                    axes[0].set_title("validation input")

                    val_input = torch.from_numpy(val_input).to(
                        device=self.device)

                    axes[1].imshow(
                        (self(val_input)).detach().squeeze().cpu().numpy(),
                        cmap='gray')
                    axes[1].set_title("validation output")
                    axes[2].imshow(val_target.squeeze(), cmap='gray')
                    axes[2].set_title("validation target")
                    plt.savefig(f"outputs/fig_{e+1:d}.png")
                    plt.close()

                self.train()

                if len(losses[0]) > 50:
                    plt.figure(figsize=(20, 20))
                    plt.plot(losses[0][50:], losses[1][50:], label="Loss")
                    plt.plot(losses[0][50:],
                             losses[2][50:],
                             label="Masked Loss")
                    plt.plot(losses[0][50:], losses[3][50:], label="L1 Reg")
                    plt.plot(losses[0][50:], losses[4][50:], label="L2 Reg")
                    plt.legend()
                    plt.xlabel("Epoch")
                    plt.ylabel("Loss")
                    plt.savefig("outputs/loss.png", dpi=1000)
                    plt.close()

            if not (e + 1) % 10:
                self.save_state_dir(
                    'outputs/saved_models',
                    "{}_{}.pth".format(self.name.lower(), e + 1))

        else:  # for-else: runs once after the epoch loop completes normally
            print('')
            self.save_state_dir('outputs/saved_models',
                                "{}_final.pth".format(self.name.lower()))
            self.eval()
            with torch.no_grad():
                test_img, test_mask = self.data_loader.load_test()
                test_pred = self(torch.from_numpy(test_img).to(
                    self.device)).detach().cpu().numpy().squeeze()
                imageio.mimsave(
                    'outputs/test_pred.gif',
                    np.concatenate(
                        [(test_img.squeeze(1) * 255).astype('uint8'),
                         (test_pred * 255).astype('uint8'),
                         (test_mask.squeeze(1) * 255).astype('uint8')],
                        axis=2))
            self.train()
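
The masked term in code example #29 restricts the squared error to non-zero target pixels (the pili) and uses a sum rather than a mean, so its magnitude grows with the number of lit pixels. A self-contained sketch of that term (not the repository's own helper):

import torch

def masked_mse(output, target):
    mask = (target != 0).float()    # 1.0 where the target has signal
    sq_err = (output - target) ** 2
    # .sum() as in the example; use .sum() / mask.sum().clamp(min=1)
    # for a scale-independent mean instead.
    return (sq_err * mask).sum()

out = torch.rand(2, 1, 8, 8)
tgt = torch.zeros(2, 1, 8, 8)
tgt[:, :, 2:5, 2:5] = 1.0           # sparse foreground pixels
print(masked_mse(out, tgt))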
Code example #30
File: train.py  Project: hustzeyu/3DfaceSiblingNet
        optimizer = torch.optim.Adam([{
            'params': model.parameters()
        }, {
            'params': metric_fc.parameters()
        }],
                                     lr=opt.lr,
                                     weight_decay=opt.weight_decay)
    scheduler = StepLR(optimizer, step_size=opt.lr_step, gamma=0.1)

    ############ train and test model ##########################
    print("*" * 100)
    print("Strat training ...")
    for epoch in range(1, opt.max_epoch + 1):
        scheduler.step()
        #learn_rate = opt.lr * 0.1**((epoch-1)//opt.lr_step)
        learn_rate = scheduler.get_lr()[0]
        print("learn_rate:%s" % learn_rate)

        model.train()
        batch_idx = 1
        for data in tqdm(trainloader):
            data_input, data_input1, label = data
            data_input = data_input.to(device)
            data_input1 = data_input1.to(device)
            label = label.to(device).long()
            if batch_idx == 1:
                tensor_board.visual_img(data_input, epoch)
            #import pdb;pdb.set_trace()
            feature = model(data_input, data_input1)
            output = metric_fc(feature, label)
            loss = criterion(output, label)
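
Code example #30 calls scheduler.step() at the top of each epoch, before any optimizer step, which was the pre-1.1 PyTorch convention. Since PyTorch 1.1 the documented order is optimizer first, scheduler after; a minimal sketch of the epoch loop in the newer style (the names are placeholders, not this project's variables):

import torch
from torch.optim.lr_scheduler import StepLR

params = [torch.zeros(1, requires_grad=True)]
optimizer = torch.optim.Adam(params, lr=0.01)
scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

for epoch in range(1, 11):
    print("learn_rate:%s" % scheduler.get_last_lr()[0])
    # ... forward / backward / optimizer.step() over batches ...
    optimizer.step()
    scheduler.step()    # after the epoch's optimizer steps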
Code example #31
best_loss_val = 100
criterion = CosineMarginCrossEntropy().cuda()
exp_lr_scheduler = StepLR(optimizer, step_size=18, gamma=0.1)
for epoch in range(num_epochs):
    exp_lr_scheduler.step()
    
    # train for one epoch
    sample_weights = train(train_loader, model, criterion, optimizer, epoch, sample_weights, neptune_ctx)

    # evaluate on validation set
    acc1, acc5, loss_val = validate(val_loader, model, criterion)
    neptune_ctx.channel_send('val-acc1', acc1)
    neptune_ctx.channel_send('val-acc5', acc5)
    neptune_ctx.channel_send('val-loss', loss_val)
    neptune_ctx.channel_send('lr', float(exp_lr_scheduler.get_lr()[0]))
    
    logger.info(f'Epoch: {epoch} Acc1: {acc1} Acc5: {acc5} Val-Loss: {loss_val}')
    
    # remember best acc@1 and save checkpoint
    is_best = acc1 >= best_acc1     
    best_acc1 = max(acc1, best_acc1)
    save_checkpoint({
        'epoch': epoch + 1,
        'arch': 'resnet18',
        'state_dict': model.state_dict(),
        'best_acc1': best_acc1,
        }, is_best, name=name + "_acc1", filename=check_filename)
        

print(f"Best ACC: {best_acc1}")