]> wimlib.net Git - wimlib/blobdiff - src/win32_common.c
WIMBoot: Update WimOverlay.dat directly when WOF not running
[wimlib] / src / win32_common.c
index e0d39ac3e1ce8be16b62eba5e84fe070507fe509..67d208d23256c5daa347dfb0cb59cf7c65890f4b 100644 (file)
 
 #include <errno.h>
 
-#include "wimlib/win32_common.h"
+#ifdef WITH_NTDLL
+#  include <winternl.h>
+#endif
 
+#include "wimlib/win32_common.h"
 #include "wimlib/assert.h"
 #include "wimlib/error.h"
+#include "wimlib/util.h"
 
-#ifdef ENABLE_ERROR_MESSAGES
-void
-win32_error(DWORD err_code)
-{
-       wchar_t *buffer;
-       DWORD nchars;
-       nchars = FormatMessageW(FORMAT_MESSAGE_FROM_SYSTEM |
-                                   FORMAT_MESSAGE_ALLOCATE_BUFFER,
-                               NULL, err_code, 0,
-                               (wchar_t*)&buffer, 0, NULL);
-       if (nchars == 0) {
-               ERROR("Error printing error message! "
-                     "Computer will self-destruct in 3 seconds.");
-       } else {
-               ERROR("Win32 error: %ls", buffer);
-               LocalFree(buffer);
-       }
-}
-#endif /* ENABLE_ERROR_MESSAGES */
-
-int
+static int
 win32_error_to_errno(DWORD err_code)
 {
        /* This mapping is that used in Cygwin.
@@ -323,72 +307,209 @@ win32_error_to_errno(DWORD err_code)
        }
 }
 
+
+void
+set_errno_from_win32_error(DWORD err)
+{
+       errno = win32_error_to_errno(err);
+}
+
 void
 set_errno_from_GetLastError(void)
 {
-       errno = win32_error_to_errno(GetLastError());
+       set_errno_from_win32_error(GetLastError());
+}
+
+#ifdef WITH_NTDLL
+void
+set_errno_from_nt_status(NTSTATUS status)
+{
+       set_errno_from_win32_error((*func_RtlNtStatusToDosError)(status));
+}
+#endif
+
+/* Given a Windows-style path, return the number of characters of the prefix
+ * that specify the path to the root directory of a drive, or return 0 if the
+ * drive is relative (or at least on the current drive, in the case of
+ * absolute-but-not-really-absolute paths like \Windows\System32) */
+static size_t
+win32_path_drive_spec_len(const wchar_t *path)
+{
+       size_t n = 0;
+
+       if (!wcsncmp(path, L"\\\\?\\", 4)) {
+               /* \\?\-prefixed path.  Check for following drive letter and
+                * path separator. */
+               if (path[4] != L'\0' && path[5] == L':' &&
+                   is_any_path_separator(path[6]))
+                       n = 7;
+       } else {
+               /* Not a \\?\-prefixed path.  Check for an initial drive letter
+                * and path separator. */
+               if (path[0] != L'\0' && path[1] == L':' &&
+                   is_any_path_separator(path[2]))
+                       n = 3;
+       }
+       /* Include any additional path separators.*/
+       if (n > 0)
+               while (is_any_path_separator(path[n]))
+                       n++;
+       return n;
+}
+
+bool
+win32_path_is_root_of_drive(const wchar_t *path)
+{
+       size_t drive_spec_len;
+       wchar_t full_path[32768];
+       DWORD ret;
+
+       ret = GetFullPathName(path, ARRAY_LEN(full_path), full_path, NULL);
+       if (ret > 0 && ret < ARRAY_LEN(full_path))
+               path = full_path;
+
+       /* Explicit drive letter and path separator? */
+       drive_spec_len = win32_path_drive_spec_len(path);
+       if (drive_spec_len > 0 && path[drive_spec_len] == L'\0')
+               return true;
+
+       /* All path separators? */
+       for (const wchar_t *p = path; *p != L'\0'; p++)
+               if (!is_any_path_separator(*p))
+                       return false;
+       return true;
 }
 
+
 /* Given a path, which may not yet exist, get a set of flags that describe the
  * features of the volume the path is on. */
 int
-win32_get_vol_flags(const wchar_t *path, unsigned *vol_flags_ret)
+win32_get_vol_flags(const wchar_t *path, unsigned *vol_flags_ret,
+                   bool *supports_SetFileShortName_ret)
 {
        wchar_t *volume;
        BOOL bret;
        DWORD vol_flags;
+       size_t drive_spec_len;
+       wchar_t filesystem_name[MAX_PATH + 1];
 
-       if (path[0] != L'\0' && path[0] != L'\\' &&
-           path[0] != L'/' && path[1] == L':')
-       {
-               /* Path starts with a drive letter; use it. */
-               volume = alloca(4 * sizeof(wchar_t));
-               volume[0] = path[0];
-               volume[1] = path[1];
-               volume[2] = L'\\';
-               volume[3] = L'\0';
-       } else {
+       if (supports_SetFileShortName_ret)
+               *supports_SetFileShortName_ret = false;
+
+       drive_spec_len = win32_path_drive_spec_len(path);
+
+       if (drive_spec_len == 0)
+               if (path[0] != L'\0' && path[1] == L':') /* Drive-relative path? */
+                       drive_spec_len = 2;
+
+       if (drive_spec_len == 0) {
                /* Path does not start with a drive letter; use the volume of
                 * the current working directory. */
                volume = NULL;
+       } else {
+               /* Path starts with a drive letter (or \\?\ followed by a drive
+                * letter); use it. */
+               volume = alloca((drive_spec_len + 2) * sizeof(wchar_t));
+               wmemcpy(volume, path, drive_spec_len);
+               /* Add trailing backslash in case this was a drive-relative
+                * path. */
+               volume[drive_spec_len] = L'\\';
+               volume[drive_spec_len + 1] = L'\0';
        }
-       bret = GetVolumeInformationW(volume, /* lpRootPathName */
-                                    NULL,  /* lpVolumeNameBuffer */
-                                    0,     /* nVolumeNameSize */
-                                    NULL,  /* lpVolumeSerialNumber */
-                                    NULL,  /* lpMaximumComponentLength */
-                                    &vol_flags, /* lpFileSystemFlags */
-                                    NULL,  /* lpFileSystemNameBuffer */
-                                    0);    /* nFileSystemNameSize */
+       bret = GetVolumeInformation(
+                       volume,                         /* lpRootPathName */
+                       NULL,                           /* lpVolumeNameBuffer */
+                       0,                              /* nVolumeNameSize */
+                       NULL,                           /* lpVolumeSerialNumber */
+                       NULL,                           /* lpMaximumComponentLength */
+                       &vol_flags,                     /* lpFileSystemFlags */
+                       filesystem_name,                /* lpFileSystemNameBuffer */
+                       ARRAY_LEN(filesystem_name));    /* nFileSystemNameSize */
        if (!bret) {
-               DWORD err = GetLastError();
-               WARNING("Failed to get volume information for path \"%ls\"", path);
-               win32_error(err);
+               set_errno_from_GetLastError();
+               WARNING_WITH_ERRNO("Failed to get volume information for "
+                                  "path \"%ls\"", path);
                vol_flags = 0xffffffff;
+               goto out;
        }
 
+       if (wcsstr(filesystem_name, L"NTFS")) {
+               /* FILE_SUPPORTS_HARD_LINKS is only supported on Windows 7 and later.
+                * Force it on anyway if filesystem is NTFS.  */
+               vol_flags |= FILE_SUPPORTS_HARD_LINKS;
+
+               if (supports_SetFileShortName_ret)
+                       *supports_SetFileShortName_ret = true;
+       }
+
+out:
        DEBUG("using vol_flags = %x", vol_flags);
        *vol_flags_ret = vol_flags;
        return 0;
 }
 
-HANDLE
-win32_open_existing_file(const wchar_t *path, DWORD dwDesiredAccess)
+static bool
+win32_modify_privilege(const wchar_t *privilege, bool enable)
 {
-       return CreateFileW(path,
-                          dwDesiredAccess,
-                          FILE_SHARE_READ,
-                          NULL, /* lpSecurityAttributes */
-                          OPEN_EXISTING,
-                          FILE_FLAG_BACKUP_SEMANTICS |
-                              FILE_FLAG_OPEN_REPARSE_POINT,
-                          NULL /* hTemplateFile */);
+       HANDLE hToken;
+       LUID luid;
+       TOKEN_PRIVILEGES newState;
+       bool ret = FALSE;
+
+       if (!OpenProcessToken(GetCurrentProcess(),
+                             TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY,
+                             &hToken))
+               goto out;
+
+       if (!LookupPrivilegeValue(NULL, privilege, &luid))
+               goto out_close_handle;
+
+       newState.PrivilegeCount = 1;
+       newState.Privileges[0].Luid = luid;
+       newState.Privileges[0].Attributes = (enable ? SE_PRIVILEGE_ENABLED : 0);
+       SetLastError(ERROR_SUCCESS);
+       ret = AdjustTokenPrivileges(hToken, FALSE, &newState, 0, NULL, NULL);
+       if (ret && GetLastError() == ERROR_NOT_ALL_ASSIGNED)
+               ret = FALSE;
+out_close_handle:
+       CloseHandle(hToken);
+out:
+       return ret;
+}
+
+static bool
+win32_modify_capture_privileges(bool enable)
+{
+       return win32_modify_privilege(SE_BACKUP_NAME, enable)
+           && win32_modify_privilege(SE_SECURITY_NAME, enable);
+}
+
+static bool
+win32_modify_apply_privileges(bool enable)
+{
+       return win32_modify_privilege(SE_RESTORE_NAME, enable)
+           && win32_modify_privilege(SE_SECURITY_NAME, enable)
+           && win32_modify_privilege(SE_TAKE_OWNERSHIP_NAME, enable);
+}
+
+static void
+win32_release_capture_and_apply_privileges(void)
+{
+       win32_modify_capture_privileges(false);
+       win32_modify_apply_privileges(false);
 }
 
 HANDLE
-win32_open_file_data_only(const wchar_t *path)
+win32_open_existing_file(const wchar_t *path, DWORD dwDesiredAccess)
 {
-       return win32_open_existing_file(path, FILE_READ_DATA);
+       return CreateFile(path,
+                         dwDesiredAccess,
+                         FILE_SHARE_READ,
+                         NULL, /* lpSecurityAttributes */
+                         OPEN_EXISTING,
+                         FILE_FLAG_BACKUP_SEMANTICS |
+                               FILE_FLAG_OPEN_REPARSE_POINT,
+                         NULL /* hTemplateFile */);
 }
 
 /* Pointers to functions that are not available on all targetted versions of
@@ -405,12 +526,60 @@ HANDLE (WINAPI *win32func_FindFirstStreamW)(LPCWSTR lpFileName,
 BOOL (WINAPI *win32func_FindNextStreamW)(HANDLE hFindStream,
                                         LPVOID lpFindStreamData) = NULL;
 
+/* Vista and later */
+BOOL (WINAPI *win32func_CreateSymbolicLinkW)(const wchar_t *lpSymlinkFileName,
+                                            const wchar_t *lpTargetFileName,
+                                            DWORD dwFlags) = NULL;
+
+#ifdef WITH_NTDLL
+
+DWORD (WINAPI *func_RtlNtStatusToDosError)(NTSTATUS status);
+
+NTSTATUS (WINAPI *func_NtQueryInformationFile)(HANDLE FileHandle,
+                                              PIO_STATUS_BLOCK IoStatusBlock,
+                                              PVOID FileInformation,
+                                              ULONG Length,
+                                              FILE_INFORMATION_CLASS FileInformationClass);
+
+NTSTATUS (WINAPI *func_NtQuerySecurityObject)(HANDLE handle,
+                                             SECURITY_INFORMATION SecurityInformation,
+                                             PSECURITY_DESCRIPTOR SecurityDescriptor,
+                                             ULONG Length,
+                                             PULONG LengthNeeded);
+
+NTSTATUS (WINAPI *func_NtQueryDirectoryFile) (HANDLE FileHandle,
+                                             HANDLE Event,
+                                             PIO_APC_ROUTINE ApcRoutine,
+                                             PVOID ApcContext,
+                                             PIO_STATUS_BLOCK IoStatusBlock,
+                                             PVOID FileInformation,
+                                             ULONG Length,
+                                             FILE_INFORMATION_CLASS FileInformationClass,
+                                             BOOLEAN ReturnSingleEntry,
+                                             PUNICODE_STRING FileName,
+                                             BOOLEAN RestartScan);
+
+NTSTATUS (WINAPI *func_NtSetSecurityObject)(HANDLE Handle,
+                                           SECURITY_INFORMATION SecurityInformation,
+                                           PSECURITY_DESCRIPTOR SecurityDescriptor);
+
+NTSTATUS (WINAPI *func_RtlCreateSystemVolumeInformationFolder)
+               (PCUNICODE_STRING VolumeRootPath);
+
+#endif /* WITH_NTDLL */
+
 static OSVERSIONINFO windows_version_info = {
        .dwOSVersionInfoSize = sizeof(OSVERSIONINFO),
 };
 
 static HMODULE hKernel32 = NULL;
 
+#ifdef WITH_NTDLL
+static HMODULE hNtdll = NULL;
+#endif
+
+static bool acquired_privileges = false;
+
 bool
 windows_version_is_at_least(unsigned major, unsigned minor)
 {
@@ -419,22 +588,28 @@ windows_version_is_at_least(unsigned major, unsigned minor)
                 windows_version_info.dwMinorVersion >= minor);
 }
 
-/* Try to dynamically load some functions */
-void
-win32_global_init(void)
+/* One-time initialization for Windows capture/apply code.  */
+int
+win32_global_init(int init_flags)
 {
-       DWORD err;
-
-       if (hKernel32 == NULL) {
-               DEBUG("Loading Kernel32.dll");
-               hKernel32 = LoadLibraryW(L"Kernel32.dll");
-               if (hKernel32 == NULL) {
-                       err = GetLastError();
-                       WARNING("Can't load Kernel32.dll");
-                       win32_error(err);
-               }
+       /* Try to acquire useful privileges.  */
+       if (!(init_flags & WIMLIB_INIT_FLAG_DONT_ACQUIRE_PRIVILEGES)) {
+               if (!win32_modify_capture_privileges(true))
+                       if (init_flags & WIMLIB_INIT_FLAG_STRICT_CAPTURE_PRIVILEGES)
+                               goto insufficient_privileges;
+               if (!win32_modify_apply_privileges(true))
+                       if (init_flags & WIMLIB_INIT_FLAG_STRICT_APPLY_PRIVILEGES)
+                               goto insufficient_privileges;
+               acquired_privileges = true;
        }
 
+       /* Get Windows version information.  */
+       GetVersionEx(&windows_version_info);
+
+       /* Try to dynamically load some functions.  */
+       if (hKernel32 == NULL)
+               hKernel32 = LoadLibrary(L"Kernel32.dll");
+
        if (hKernel32) {
                win32func_FindFirstStreamW = (void*)GetProcAddress(hKernel32,
                                                                   "FindFirstStreamW");
@@ -444,19 +619,69 @@ win32_global_init(void)
                        if (!win32func_FindNextStreamW)
                                win32func_FindFirstStreamW = NULL;
                }
+               win32func_CreateSymbolicLinkW = (void*)GetProcAddress(hKernel32,
+                                                                     "CreateSymbolicLinkW");
        }
 
-       GetVersionEx(&windows_version_info);
+#ifdef WITH_NTDLL
+       if (hNtdll == NULL)
+               hNtdll = LoadLibrary(L"ntdll.dll");
+
+       if (hNtdll) {
+               func_RtlNtStatusToDosError  =
+                       (void*)GetProcAddress(hNtdll, "RtlNtStatusToDosError");
+               if (func_RtlNtStatusToDosError) {
+
+                       func_NtQuerySecurityObject  =
+                               (void*)GetProcAddress(hNtdll, "NtQuerySecurityObject");
+
+                       func_NtQueryDirectoryFile   =
+                               (void*)GetProcAddress(hNtdll, "NtQueryDirectoryFile");
+
+                       func_NtQueryInformationFile =
+                               (void*)GetProcAddress(hNtdll, "NtQueryInformationFile");
+
+                       func_NtSetSecurityObject    =
+                               (void*)GetProcAddress(hNtdll, "NtSetSecurityObject");
+                       func_RtlCreateSystemVolumeInformationFolder =
+                               (void*)GetProcAddress(hNtdll, "RtlCreateSystemVolumeInformationFolder");
+               }
+       }
+
+       DEBUG("FindFirstStreamW       @ %p", win32func_FindFirstStreamW);
+       DEBUG("FindNextStreamW        @ %p", win32func_FindNextStreamW);
+       DEBUG("CreateSymbolicLinkW    @ %p", win32func_CreateSymbolicLinkW);
+       DEBUG("RtlNtStatusToDosError  @ %p", func_RtlNtStatusToDosError);
+       DEBUG("NtQuerySecurityObject  @ %p", func_NtQuerySecurityObject);
+       DEBUG("NtQueryDirectoryFile   @ %p", func_NtQueryDirectoryFile);
+       DEBUG("NtQueryInformationFile @ %p", func_NtQueryInformationFile);
+       DEBUG("NtSetSecurityObject    @ %p", func_NtSetSecurityObject);
+       DEBUG("RtlCreateSystemVolumeInformationFolder    @ %p",
+             func_RtlCreateSystemVolumeInformationFolder);
+#endif
+
+       return 0;
+
+insufficient_privileges:
+       win32_release_capture_and_apply_privileges();
+       return WIMLIB_ERR_INSUFFICIENT_PRIVILEGES;
 }
 
 void
 win32_global_cleanup(void)
 {
+       if (acquired_privileges)
+               win32_release_capture_and_apply_privileges();
        if (hKernel32 != NULL) {
-               DEBUG("Closing Kernel32.dll");
                FreeLibrary(hKernel32);
                hKernel32 = NULL;
        }
+#ifdef WITH_NTDLL
+       if (hNtdll != NULL) {
+               FreeLibrary(hNtdll);
+               hNtdll = NULL;
+       }
+#endif
 }
 
 #endif /* __WIN32__ */