예제 #1
0
def worker_fn(rank, world_size):
    """Train a WReN model on the PGM dataset as one process of a DDP job.

    Every epoch runs a validation pass over the full validation split (each
    rank evaluates all of it, since the validation loader has no sampler),
    followed by one training pass over this rank's DistributedSampler shard.
    Rank 0 additionally writes TensorFlow scalar summaries to ``./log`` and
    saves weights.

    If ``weights.pt`` exists it is loaded first ("warm start"); in that case
    the linear LR warmup and the periodic ``_cp<epoch>`` checkpoints are
    skipped.

    Args:
        rank: Index of this process; also used as its CUDA device index.
        world_size: Total number of processes. The global batch size (512)
            is divided evenly between them.
    """
    setup(rank, world_size)

    weights_filename = "weights.pt"
    batch_size = 512
    epochs = 240
    warmup_epochs = 8
    use_mixed_precision = True

    batch_size = batch_size // world_size  # batch size per worker

    # Data
    all_data = os.listdir(datapath_preprocessed)
    train_filenames = [p for p in all_data if re.match(r'^PGM_' + re.escape(dataset_name) + r'_train_(\d+)\.npz$', p) is not None]
    val_filenames = [p for p in all_data if re.match(r'^PGM_' + re.escape(dataset_name) + r'_val_(\d+)\.npz$', p) is not None]
    train_dataset = PgmDataset(train_filenames)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    # Shuffling is done by the sampler, so the loader itself does not shuffle.
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=8, pin_memory=False, sampler=train_sampler)
    # No sampler: every rank iterates over the whole validation split.
    val_dataloader = DataLoader(PgmDataset(val_filenames), batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=False)

    # Model
    device_ids = [rank]

    model = WReN(2).to(device_ids[0])  # 3-layer MLRN

    if weights_filename is not None and os.path.isfile("./" + weights_filename):
        model.load_state_dict(torch.load(weights_filename, map_location='cpu'))
        print('Weights loaded')
        cold_start = False
    else:
        print('No weights found')
        cold_start = True

    # Loss and optimizer
    final_lr = 2e-3

    def add_module_params_with_decay(module, weight_decay, param_groups):
        """Append two param groups for ``module`` to ``param_groups``.

        Trainable parameters whose name ends in ``bias`` go into a group with
        no explicit weight decay (so the optimizer default applies); all other
        trainable parameters go into a group with ``weight_decay``.
        """
        group_with_decay = []
        group_without_decay = []
        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue
            if name == 'bias' or name.endswith('bias'):
                group_without_decay.append(param)
            else:
                group_with_decay.append(param)
        param_groups.append({"params": group_with_decay, "weight_decay": weight_decay})
        param_groups.append({"params": group_without_decay})

    optimizer_param_groups = []

    add_module_params_with_decay(model.conv, 2e-1, optimizer_param_groups)
    add_module_params_with_decay(model.post_cnn_linear, 2e-1, optimizer_param_groups)
    add_module_params_with_decay(model.g, 2e-1, optimizer_param_groups)
    add_module_params_with_decay(model.h, 2e-1, optimizer_param_groups)
    add_module_params_with_decay(model.f, 2e-1, optimizer_param_groups)
    add_module_params_with_decay(model.f_final, 2e-1, optimizer_param_groups)

    optimizer = Lamb(optimizer_param_groups, lr=final_lr)

    # Keep a handle on the unwrapped module so saved state_dict keys do not
    # carry the DDP "module." prefix.
    base_model = model
    if use_mixed_precision:
        # amp.initialize must run before wrapping the model in DDP.
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")  # Mixed Precision

    lossFunc = torch.nn.CrossEntropyLoss()
    softmax = torch.nn.Softmax(dim=1)

    # Parallel distributed model
    device = device_ids[0]
    torch.cuda.set_device(device)
    parallel_model = torch.nn.parallel.DistributedDataParallel(model, device_ids)

    if rank == 0:
        # accuracy logging
        sess = tf.Session()
        train_acc_placeholder = tf.placeholder(tf.float32, shape=())
        train_acc_summary = tf.summary.scalar('training_acc', train_acc_placeholder)
        val_acc_placeholder = tf.placeholder(tf.float32, shape=())
        val_acc_summary = tf.summary.scalar('validation_acc', val_acc_placeholder)
        writer = tf.summary.FileWriter("log", sess.graph)

    # Nominal number of optimizer steps in the warmup phase: each step consumes
    # batch_size * world_size samples globally.
    warmup_steps = warmup_epochs * len(train_dataset) / (batch_size * world_size)

    # training loop
    acc = []  # per-batch training accuracies since the last log point
    global_step = 0
    for epoch in range(epochs):
        # Reseed the sampler so each epoch uses a different shuffle.
        train_sampler.set_epoch(epoch)

        # Validation
        val_acc = []
        parallel_model.eval()
        with torch.no_grad():
            for i, (local_batch, local_labels) in enumerate(val_dataloader):
                local_batch, targets = local_batch.to(device), local_labels.to(device)

                answer, _ = parallel_model(local_batch.type(torch.float32))

                # Calc accuracy
                answerSoftmax = softmax(answer)
                maxIndex = answerSoftmax.argmax(dim=1)

                correct = maxIndex.eq(targets)
                # float32, not float16: these values are summed over many
                # batches below, and a float16 running sum rounds each addend
                # once it grows past a few hundred (spacing 0.5 above 512).
                accuracy = correct.float().mean(dim=0)
                val_acc.append(accuracy)

                if i % 50 == 0 and rank == 0:
                    print("batch " + str(i))

        # Mean of per-batch accuracies (a smaller final batch is weighted like
        # a full one).
        total_val_acc = sum(val_acc) / len(val_acc)
        print('Validation accuracy: ' + str(total_val_acc.item()))
        if rank == 0:
            summary = sess.run(val_acc_summary, feed_dict={val_acc_placeholder: total_val_acc.item()})
            writer.add_summary(summary, global_step=global_step)

        # Training
        parallel_model.train()
        for i, (local_batch, local_labels) in enumerate(train_dataloader):
            global_step = global_step + 1

            if cold_start and epoch < warmup_epochs:
                # Linear LR warmup. The factor is clamped at 1: DistributedSampler
                # pads each shard and the last batch may be partial, so the real
                # step count can exceed warmup_steps. Without the clamp the LR
                # would overshoot final_lr and, as it is never reset afterwards,
                # stay above it for the rest of training.
                lr = final_lr * min(1.0, global_step / warmup_steps)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr

            local_batch, targets = local_batch.to(device_ids[0]), local_labels.to(device_ids[0])

            optimizer.zero_grad()
            answer, activation_loss = parallel_model(local_batch.type(torch.float32))

            loss = lossFunc(answer, targets) + activation_loss * 2e-3

            # Calc accuracy
            answerSoftmax = softmax(answer)
            maxIndex = answerSoftmax.argmax(dim=1)

            correct = maxIndex.eq(targets)
            # float32 for the same reason as in validation: `acc` is summed over
            # up to ~1000 batches before logging.
            accuracy = correct.float().mean(dim=0)
            acc.append(accuracy)

            # Training step
            if use_mixed_precision:
                # Gradients are unscaled when the context exits, so clipping
                # below sees true gradient magnitudes.
                with amp.scale_loss(loss, optimizer) as scaled_loss:  # Mixed precision
                    scaled_loss.backward()
            else:
                loss.backward()

            # NOTE: clipping model params matches the optimizer's params under
            # opt_level "O1"; apex recommends amp.master_params(optimizer) if
            # the opt_level is ever changed to one with separate master weights.
            grad_norm = torch.nn.utils.clip_grad_norm_(parallel_model.parameters(), 1e1)

            optimizer.step()

            if i % 50 == 0 and rank == 0:
                print("epoch " + str(epoch) + " batch " + str(i))
                print("loss", loss)
                print("activation loss", activation_loss)
                print(grad_norm)

            # logging and saving weights
            if i % 1000 == 999:
                trainAcc = sum(acc) / len(acc)
                acc = []
                print('Training accuracy: ' + str(trainAcc.item()))
                if rank == 0:
                    if weights_filename is not None:
                        torch.save(base_model.state_dict(), weights_filename)
                        print('Weights saved')

                    summary = sess.run(train_acc_summary, feed_dict={train_acc_placeholder: trainAcc.item()})
                    writer.add_summary(summary, global_step=global_step)

        if cold_start and weights_filename is not None and epoch % 10 == 0 and rank == 0:
            torch.save(base_model.state_dict(), weights_filename + "_cp" + str(epoch))
            print('Checkpoint saved')

    cleanup()
예제 #2
0
    # Fragment of a train/test routine; its enclosing definition is not visible here.
    # Runs one pass over dataloader_train with optimizer updates, then one pass over
    # dataloader_test, accumulating correct top-1 prediction COUNTS (not ratios)
    # in acc_train / acc_test. These become tensors, since .item() is not called.
    acc_train = 0
    acc_test = 0

    model.train()
    for img, labels in dataloader_train:
        #img, labels = batch
        img, labels = img.to(device), labels.to(device)
        #print(labels[0])
        #labelsmat = F.one_hot(labels, num_classes=10).to(device)
        output = model(img)
        #loss = torch.sum((output-labelsmat)**2)
        loss = F.cross_entropy(output, labels)
        # Number of correct argmax predictions in this batch.
        acc_train += torch.sum(torch.argmax(output,
                                            dim=-1) == labels)  #.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # NOTE(review): total_loss is not initialized in the visible lines;
        # confirm it is set (e.g. to 0) before this point, otherwise this raises NameError.
        # detach() keeps the running total from holding on to autograd graphs.
        total_loss += loss.detach()

    # Testing
    model.eval()
    # NOTE(review): this loop is not wrapped in torch.no_grad(), so forward passes
    # presumably record autograd history needlessly (extra memory); consider adding it.
    for img, labels in dataloader_test:
        #img, labels = batch
        img, labels = img.to(device), labels.to(device)
        #print(labels[0])
        #labelsmat = F.one_hot(labels, num_classes=10).to(device)
        output = model(img)
        # Number of correct argmax predictions in this batch.
        acc_test += torch.sum(torch.argmax(output, dim=-1) == labels)