pin_memory=True,
                                          worker_init_fn=_init_fn)

##############
# LOAD MODEL #
##############

# Build the model: from scratch when no checkpoint path was given, otherwise
# from the architecture description stored inside the checkpoint.
if args.restore is None:
    # Fresh run: nothing to restore, so only CLI-provided settings are used.
    checkpoint = None
    trainer_params = None
    model = autoregressive_model.AutoregressiveFR(
        channels=args.channels,
        dropout_p=args.dropout_p,
    )
else:
    print("Restoring model from:", args.restore)
    # Force tensors onto the CPU when running on CPU; otherwise let
    # torch.load use its default placement (map_location=None).
    checkpoint = torch.load(
        args.restore,
        map_location=None if device.type != 'cpu' else 'cpu',
    )
    dims = checkpoint['model_dims']
    hyperparams = checkpoint['model_hyperparams']
    trainer_params = checkpoint['train_params']
    # NOTE(review): only the architecture is rebuilt here; no weights are
    # loaded in this block -- presumably that happens later from
    # `checkpoint`; confirm against the trainer.
    model = autoregressive_model.AutoregressiveFR(
        dims=dims,
        hyperparams=hyperparams,
        dropout_p=args.dropout_p,
    )
model.to(device)

################
# RUN TRAINING #
################

trainer = autoregressive_train.AutoregressiveTrainer(
    model=model,
    data_loader=loader,
Beispiel #2
0
        self.set_parameter(layer.weight_g, layer_name + 'g')
        self.set_parameter(layer.weight_v, layer_name + 'W', permute=(3, 2, 0, 1))  # HWIO to OIHW

    def load_layer_norm(self, layer, layer_name):
        """Load a layer-norm layer's bias and weight via ``set_parameter``.

        ``layer.bias`` is read from the key ``layer_name + 'b'`` and
        ``layer.weight`` from ``layer_name + 'g'``, in that order.
        """
        # (attribute on the layer, key suffix) -- bias is handled first.
        suffix_by_attr = (('bias', 'b'), ('weight', 'g'))
        for attr, suffix in suffix_by_attr:
            self.set_parameter(getattr(layer, attr), layer_name + suffix)

    def set_parameter(self, parameter, name='', permute=()):
        """Overwrite ``parameter`` with the tensor stored under ``name``.

        The value is fetched with ``self.reader.get_tensor(name)``, converted
        to ``parameter``'s dtype and device, optionally axis-permuted, and
        shape-checked before being installed as ``parameter.data``.

        Args:
            parameter: Destination tensor (e.g. an ``nn.Parameter``); its
                ``dtype``, ``device`` and ``shape`` define the target format.
            name: Key passed to ``self.reader.get_tensor``. On success it is
                also removed from ``self.unused_keys``.
            permute: Axis order applied with ``Tensor.permute`` before the
                shape check (e.g. ``(3, 2, 0, 1)`` for HWIO -> OIHW). An
                empty tuple means the data is used as stored.

        Raises:
            ConversionError: If the (permuted) stored tensor's shape differs
                from ``parameter.shape``. Nothing is modified in that case.
        """
        # torch.as_tensor avoids a copy when dtype/device already match, so
        # the result may alias the array returned by the reader.
        new_data = torch.as_tensor(
            self.reader.get_tensor(name),
            dtype=parameter.dtype,
            device=parameter.device,
        )

        # The loaded tensor is deliberately NOT flagged with requires_grad:
        # assigning to `.data` keeps the parameter's own autograd state, so
        # the flag had no effect on the result, but it raised for
        # non-floating-point dtypes and made permute() record a graph.
        if permute:
            new_data = new_data.permute(*permute).contiguous()

        if new_data.shape != parameter.shape:
            raise ConversionError('mismatched shapes: {} to {} at {}'.format(
                tuple(new_data.shape), tuple(parameter.shape), name))

        parameter.data = new_data
        # Bookkeeping: this key has now been consumed.
        self.unused_keys.remove(name)


if __name__ == '__main__':
    # Manual smoke test: build a default model, load the checkpoint given as
    # the first command-line argument into it, and list what was left over.
    import sys
    import autoregressive_model

    model_test = autoregressive_model.AutoregressiveFR()
    reader = TFReader(sys.argv[1])
    reader.load_autoregressive_fr(model_test)

    # Keys still in `unused_keys` after loading, with the "Backprop"-prefixed
    # ones filtered out of the report.
    leftover_keys = [
        key
        for key in reader.unused_keys
        if not key.startswith("Backprop")
    ]
    print(leftover_keys)